"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import dataclasses
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, is_generation_model, is_multimodal_model

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state a request."""

    out_list: List  # Outputs received since the event was last cleared
    finished: bool  # Whether the final output for this request has arrived
    event: asyncio.Event  # Set by handle_loop() when new output is appended


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
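        # (requests go out to the scheduler over a PUSH socket; detokenized
        #  results come back from the detokenizer over a PULL socket)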
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )
        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing, so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
            )

        obj.post_init()
        is_single = obj.is_single
        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _send_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        if not is_cache_for_prefill:  # The normal case with a single prompt
            if index is None:
                rid = obj.rid
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = obj.text if obj.text is not None else None
                    input_ids = obj.input_ids

                sampling_params = self._get_sampling_params(obj.sampling_params)
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data, input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob
                    logprob_start_len = obj.logprob_start_len
                    top_logprobs_num = obj.top_logprobs_num
            else:
                rid = obj.rid[index]
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv[index]
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text[input_id_index]
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = (
                        obj.text[input_id_index] if obj.text is not None else None
                    )
                    input_ids = obj.input_ids[input_id_index]

                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob[index]
                    logprob_start_len = obj.logprob_start_len[index]
                    top_logprobs_num = obj.top_logprobs_num[index]

            self._validate_input_length(input_ids)

        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj.input_ids is List[List[int]]
                        input_ids = obj.input_ids[input_id_index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj.input_ids is List[List[int]]
                    input_ids = obj.input_ids[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            # max_new_tokens = 0: this request only prefills the shared prompt so
            # the parallel sampling requests below can reuse its KV cache.
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send the tokenized request to the scheduler
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[input_id_index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

        self.send_to_scheduler.send_pyobj(tokenized_obj)
        return rid, input_ids

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        rid, input_ids = await self._send_single_request(
            obj,
            index,
            input_id_index=input_id_index,
            is_cache_for_prefill=is_cache_for_prefill,
        )

        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state

        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj,
                        request,
                        index=i,
                        input_id_index=i,
                        is_cache_for_prefill=True,
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None:
                    obj.input_ids = input_id_result
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # With parallel sampling, the prefill requests sent above occupy
                    # the first rid slots, so the effective index becomes
                    # j + i * (parallel_sample_num - 1) + batch_size - 1.
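                    # Example (hypothetical numbers): with batch_size=2 and
                    # obj.parallel_sample_num=2 (so parallel_sample_num=3 here),
                    # the prefill requests used rid indices 0-1 and the generation
                    # requests map to indices 2-3 (i=0) and 4-5 (i=1).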
                    index += batch_size - 1 - i

                rid, _ = await self._send_single_request(
                    obj, index, input_id_index=i, is_cache_for_prefill=False
                )

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)
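        # output_list is indexed by each response's "index" field, so non-streaming
        # outputs are returned in request order regardless of which request
        # finishes first.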

        # Fetch results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
        while True:
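            # Wake up periodically (4 s timeout) so a disconnected client can be
            # detected and its request aborted even when no new output has arrived.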
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()
        self.send_to_scheduler.send_pyobj(req)
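        # handle_loop() resolves this future when the scheduler replies with a
        # GetMemPoolSizeReqOutput.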
        self.mem_pool_size = asyncio.Future()
        return await self.mem_pool_size

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0.001)
                self.send_to_scheduler.send_pyobj(obj)
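                # handle_loop() resolves this future when the scheduler replies
                # with an UpdateWeightReqOutput.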
                self.model_update_result = asyncio.Future()
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
            return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                self.mem_pool_size.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
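                # The request may have been aborted already; ignore late outputs.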
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs