tokenizer_manager.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import dataclasses
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    ProfileReq,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

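    # Rough usage sketch (illustrative only; the real wiring lives in sglang.srt.server,
    # and the exact GenerateReqInput fields are defined in io_struct):
    #   manager = TokenizerManager(server_args, port_args)
    #   async for out in manager.generate_request(GenerateReqInput(text="Hello")):
    #       ...  # each `out` is a dict such as {"text": ..., "meta_info": ...}
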
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")
        self.send_to_scheduler = context.socket(zmq.PUSH)
        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")
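        # Request flow: generate_request() tokenizes the input and pushes a
        # Tokenized*ReqInput to the scheduler over the PUSH socket; the detokenizer
        # sends Batch*Out results back over the PULL socket, where handle_loop()
        # routes them to the waiting ReqState.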
        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )
        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()
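        # The placeholder is replaced with a real image processor below when the
        # model is multimodal; text-only and embedding models keep the dummy one.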

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
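                # Presumably disables the HF tokenizers' internal thread pool to
                # avoid fork-related parallelism warnings in the worker processes.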

                # We want to parallelize the image pre-processing, so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor.image_processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response
    async def _send_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        if not is_cache_for_prefill:  # The normal case with a single prompt
            if index is None:
                rid = obj.rid
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = obj.text if obj.text is not None else None
                    input_ids = obj.input_ids

                sampling_params = self._get_sampling_params(obj.sampling_params)
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data, obj
                    )
                    return_logprob = obj.return_logprob
                    logprob_start_len = obj.logprob_start_len
                    top_logprobs_num = obj.top_logprobs_num
            else:
                rid = obj.rid[index]
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv[index]
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text[input_id_index]
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = (
                        obj.text[input_id_index] if obj.text is not None else None
                    )
                    input_ids = obj.input_ids[input_id_index]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], obj
                    )
                    return_logprob = obj.return_logprob[index]
                    logprob_start_len = obj.logprob_start_len[index]
                    top_logprobs_num = obj.top_logprobs_num[index]
            self._validate_input_length(input_ids)
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj.input_ids is List[List[int]]
                        input_ids = obj.input_ids[input_id_index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj.input_ids is List[List[int]]
                    input_ids = obj.input_ids[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], obj
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send the tokenized request to the scheduler
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[input_id_index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        return rid, input_ids

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
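        # Normal case: send one request and stream its outputs back to the caller.
        # When is_cache_for_prefill is True (parallel sampling), only the shared
        # prompt is sent (with max_new_tokens=0) to warm the prefix cache, and the
        # generator yields the prompt's input_ids instead of model outputs.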
        rid, input_ids = await self._send_single_request(
            obj,
            index,
            input_id_index=input_id_index,
            is_cache_for_prefill=is_cache_for_prefill,
        )
        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids
    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
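                # (one extra prefill-only request per prompt, hence the
                # parallel_sample_num += 1 below; these requests generate no tokens)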
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj,
                        request,
                        index=i,
                        input_id_index=i,
                        is_cache_for_prefill=True,
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None:
                    obj.input_ids = input_id_result
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # When using parallel sampling, account for the prefill requests
                    # sent above, so the index is j + i * (parallel_sample_num - 1) + batch_size - 1
                    index += batch_size - 1 - i
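                    # e.g. batch_size=2, parallel_sample_num bumped from 2 to 3:
                    #   i=0 -> indices 2, 3 (indices 0 and 1 were used by the
                    #   prefill requests above); i=1 -> indices 4, 5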
                rid, _ = await self._send_single_request(
                    obj, index, input_id_index=i, is_cache_for_prefill=False
                )
                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state
                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)
        # Fetch results
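        # Multiplex the per-request generators: each task awaits the next chunk from
        # one generator; exhausted generators are removed, and in non-streaming mode
        # the final results are collected into output_list by their response index.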
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)
                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result
                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]
        if not is_stream:
            yield output_list
    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
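        # Wake up every few seconds even without new output so that a disconnected
        # client can be detected and its request(s) aborted.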
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]
            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")
            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]
    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)
    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)
    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0.001)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()
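                # Resolved by handle_loop() once an UpdateWeightReqOutput is received.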
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
            return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"
            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
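                    # decode_ids appears to be a flat list covering all requests in the
                    # batch, with read_offsets[i] marking the end of request i's slice.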
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs