"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import dataclasses
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state a request."""

    out_list: List  # accumulated outputs for this request
    finished: bool  # whether the final output has been produced
    event: asyncio.Event  # set by handle_loop when new output is available


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
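        # A PULL socket receives detokenized results from the detokenizer process,
        # and a PUSH socket forwards tokenized requests to the scheduler.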
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")

        self.send_to_scheduler = context.socket(zmq.PUSH)
        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )
        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
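        # Multimodal models need a full processor (tokenizer + image processing);
        # text-only models only need a tokenizer.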
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
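        # rid_to_state maps a request id to its ReqState until the request finishes.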
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Block new requests while a model weight update is in progress.
        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. Please add `--is-embedding` when launching the server or try another model."
            )

        obj.post_init()
        is_single = obj.is_single
        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _send_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
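        # Tokenize a single prompt (or one element of a batch) and send it to the
        # scheduler. Returns (rid, input_ids) so the caller can track its state.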
        if not is_cache_for_prefill:  # The normal case with a single prompt
            if index is None:
                rid = obj.rid
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = obj.text if obj.text is not None else None
                    input_ids = obj.input_ids

                sampling_params = self._get_sampling_params(obj.sampling_params)
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data, input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob
                    logprob_start_len = obj.logprob_start_len
                    top_logprobs_num = obj.top_logprobs_num
            else:
                rid = obj.rid[index]
                if hasattr(obj, "conv"):
                    # reward model
                    conv = obj.conv[index]
                    input_text = self.tokenizer.apply_chat_template(
                        conv, tokenize=False
                    )
                    input_ids = self.tokenizer.encode(input_text)
                elif obj.input_ids is None:
                    input_text = obj.text[input_id_index]
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    input_text = (
                        obj.text[input_id_index] if obj.text is not None else None
                    )
                    input_ids = obj.input_ids[input_id_index]

                sampling_params = self._get_sampling_params(obj.sampling_params[index])
                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], input_text or input_ids, obj
                    )
                    if image_inputs and "input_ids" in image_inputs:
                        input_ids = image_inputs["input_ids"]
                    return_logprob = obj.return_logprob[index]
                    logprob_start_len = obj.logprob_start_len[index]
                    top_logprobs_num = obj.top_logprobs_num[index]

            self._validate_input_length(input_ids)

yichuan~'s avatar
yichuan~ committed
234
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[input_id_index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[input_id_index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send the tokenized request to the scheduler
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[input_id_index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )

        self.send_to_scheduler.send_pyobj(tokenized_obj)
        return rid, input_ids

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        input_id_index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
        rid, input_ids = await self._send_single_request(
            obj,
            index,
            input_id_index=input_id_index,
            is_cache_for_prefill=is_cache_for_prefill,
        )

        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state

        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj,
                        request,
                        index=i,
                        input_id_index=i,
                        is_cache_for_prefill=True,
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None:
                    obj.input_ids = input_id_result
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # Account for the prefill requests sent above: the effective
                    # index is j + i * (parallel_sample_num - 1) + batch_size - 1.
                    index += batch_size - 1 - i

                rid, _ = await self._send_single_request(
                    obj, index, input_id_index=i, is_cache_for_prefill=False
                )

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        # Fetch results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
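        # Prefill-only requests (max_new_tokens == 0) skip normalization and verification.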
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
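        # Wait with a timeout so a disconnected client can be detected and its
        # request aborted even when no new output arrives.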
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()
        self.send_to_scheduler.send_pyobj(req)
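        # handle_loop resolves this future when the scheduler's reply arrives.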
        self.mem_pool_size = asyncio.Future()
        return await self.mem_pool_size

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
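        # Only one weight update may run at a time; in-flight generation requests
        # are drained before the new weights are loaded.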
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one configured in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0.001)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
            return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut,
                UpdateWeightReqOutput, GetMemPoolSizeReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                self.mem_pool_size.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
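                # Convert the received batch entry into a per-request result dict.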
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs