"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import dataclasses
import json
import logging
import os
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.hf_transformers_utils import (
    get_config,
    get_context_length,
    get_processor,
    get_tokenizer,
)
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    TokenizedRewardReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model

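# Use uvloop as the asyncio event loop implementation for lower overhead.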
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state a request."""

    out_list: List  # Outputs received so far for this request
    finished: bool  # Whether the request has finished
    event: asyncio.Event  # Set whenever a new output is appended to out_list


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        self.server_args = server_args

        # Init inter-process communication
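        # A PULL socket receives detokenized results from the detokenizer process, and a
        # PUSH socket sends tokenized requests to the scheduler.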
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = context.socket(zmq.PULL)
        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

        self.send_to_scheduler = context.socket(zmq.PUSH)
        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.hf_config = get_config(
            self.model_path,
            trust_remote_code=server_args.trust_remote_code,
            model_override_args=json.loads(server_args.json_model_override_args),
        )
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
        self.context_len = server_args.context_length or get_context_length(
            self.hf_config
        )
        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if is_multimodal_model(self.hf_config.architectures):
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize image pre-processing, so we create an executor for it
                self.image_processor = get_image_processor(
                    self.hf_config, server_args, self.processor.image_processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
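        """Tokenize a request (single or batch) and asynchronously yield its outputs."""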
        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        obj.post_init()
        is_single = obj.is_single

        if is_single:
            async for response in self._handle_single_request(obj, request):
                yield response
        else:
            async for response in self._handle_batch_request(obj, request):
                yield response

    async def _handle_single_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        is_cache_for_prefill: Optional[bool] = False,
    ):
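        # `index` selects one element of a batched input. When `is_cache_for_prefill` is
        # set, a prefill-only request (max_new_tokens=0) is sent so the shared prefix is
        # cached before parallel sampling.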
        if not is_cache_for_prefill:  # The normal case with a single prompt
            not_use_index = index is None

            rid = obj.rid if not_use_index else obj.rid[index]
            input_text = obj.text if not_use_index else obj.text[index]
            if hasattr(obj, "conv"):
                # reward model
                assert self.tokenizer is not None
                conv = obj.conv if not_use_index else obj.conv[index]
                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
                input_ids = self.tokenizer.encode(input_text)
            elif obj.input_ids is None:
                assert self.tokenizer is not None
                input_ids = self.tokenizer.encode(input_text)
            else:
                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

            self._validate_input_length(input_ids)

            sampling_params = self._get_sampling_params(
                obj.sampling_params if not_use_index else obj.sampling_params[index]
            )

            if self.is_generation:
                image_inputs = await self.image_processor.process_images_async(
                    obj.image_data if not_use_index else obj.image_data[index], obj
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
                )
                logprob_start_len = (
                    obj.logprob_start_len
                    if not_use_index
                    else obj.logprob_start_len[index]
                )
                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
                    else obj.top_logprobs_num[index]
                )
        else:  # A prefill request to cache the common prompt for parallel sampling
            assert self.is_generation
            if obj.text is not None:
                if isinstance(obj.text, list):
                    input_text = obj.text[index]
                    rid = obj.rid[index]
                else:
                    input_text = obj.text
                    rid = obj.rid[0]
                if self.tokenizer is not None:
                    input_ids = self.tokenizer.encode(input_text)
                else:
                    assert obj.input_ids is not None
                    input_ids = obj.input_ids
                    if isinstance(obj.input_ids, list) and isinstance(
                        obj.input_ids[0], list
                    ):
                        # when obj["input_ids"] is List[List[int]]
                        input_ids = obj.input_ids[index]
                        rid = obj.rid[index]
                    else:
                        input_ids = obj.input_ids
                        rid = obj.rid[0]
            else:
                input_text = None
                if isinstance(obj.input_ids, list) and isinstance(
                    obj.input_ids[0], list
                ):
                    # when obj["input_ids"] is List[List[int]]
                    input_ids = obj.input_ids[index]
                    rid = obj.rid[index]
                else:
                    input_ids = obj.input_ids
                    rid = obj.rid[0]

            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data[0], obj
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]

        # Send the tokenized request to the scheduler
        if self.is_generation:
            tokenized_obj = TokenizedGenerateReqInput(
                rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                (
                    obj.lora_path[index]
                    if isinstance(obj.lora_path, list)
                    else obj.lora_path
                ),
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        else:
            assert isinstance(obj, RewardReqInput)
            tokenized_obj = TokenizedRewardReqInput(
                rid,
                input_text,
                input_ids,
                sampling_params,
            )
        self.send_to_scheduler.send_pyobj(tokenized_obj)

        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        request: Optional[fastapi.Request] = None,
    ):
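        # With parallel sampling, prefill-only requests are sent first to cache the
        # common prefix; afterwards one request is sent per (prompt, sample) pair and the
        # responses are collected or streamed as they arrive.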
        batch_size = obj.batch_size
        if self.is_generation:
            parallel_sample_num = obj.parallel_sample_num

            if parallel_sample_num != 1:
                # Send prefill requests to cache the common prefix
                parallel_sample_num += 1
                input_id_result = [] if obj.input_ids is None else None
                for i in range(batch_size):
                    async for input_id in self._handle_single_request(
                        obj, request, index=i, is_cache_for_prefill=True
                    ):
                        if input_id_result is not None:
                            input_id_result.append(input_id)
                if input_id_result is not None and len(input_id_result) > 1:
                    obj.input_ids = input_id_result
                elif input_id_result is not None:
                    obj.input_ids = input_id_result[0]
        else:
            parallel_sample_num = 1

        # First send out all requests
        generators = []
        for i in range(batch_size):
            for j in range(parallel_sample_num):
                if j == 0 and parallel_sample_num != 1:
                    continue
                index = i * parallel_sample_num + j
                if parallel_sample_num != 1:
                    # When using parallel sampling, account for the prefill-only requests
                    # sent earlier: the effective index is
                    # j + i * (parallel_sample_num - 1) + batch_size - 1
                    index += batch_size - 1 - i
                rid = obj.rid[index]
                if parallel_sample_num == 1:
                    # select operation
                    if hasattr(obj, "conv"):
                        # reward model
                        conv = obj.conv[i]
                        input_text = self.tokenizer.apply_chat_template(
                            conv, tokenize=False
                        )
                        input_ids = self.tokenizer.encode(input_text)
                    elif obj.input_ids is None:
                        input_text = obj.text[i]
                        input_ids = self.tokenizer.encode(input_text)
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                else:
                    assert obj.input_ids is not None
                    if batch_size == 1:
                        input_text = None
                        input_ids = obj.input_ids
                    else:
                        input_text = None
                        input_ids = obj.input_ids[i]
                sampling_params = self._get_sampling_params(obj.sampling_params[index])

                if self.is_generation:
                    image_inputs = await self.image_processor.process_images_async(
                        obj.image_data[index], obj
                    )

                    tokenized_obj = TokenizedGenerateReqInput(
                        rid,
                        input_text,
                        input_ids,
                        image_inputs,
                        sampling_params,
                        obj.return_logprob[index],
                        obj.logprob_start_len[index],
                        obj.top_logprobs_num[index],
                        obj.stream,
                        (
                            obj.lora_path[index]
                            if isinstance(obj.lora_path, list)
                            else obj.lora_path
                        ),
                    )
                elif isinstance(obj, EmbeddingReqInput):
                    tokenized_obj = TokenizedEmbeddingReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                else:
                    assert isinstance(obj, RewardReqInput)
                    tokenized_obj = TokenizedRewardReqInput(
                        rid,
                        input_text,
                        input_ids,
                        sampling_params,
                    )
                self.send_to_scheduler.send_pyobj(tokenized_obj)

                event = asyncio.Event()
                state = ReqState([], False, event)
                self.rid_to_state[rid] = state

                generators.append(
                    self._wait_for_response(
                        state,
                        obj,
                        rid,
                        request,
                        index=index,
                        response_index=len(generators),
                    )
                )

        # Then process the responses based on streaming option
        is_stream = hasattr(obj, "stream") and obj.stream

        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)

        # Recv results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            for task in done:
                cur_index = tasks.index(task)

                try:
                    result = task.result()

                    if is_stream:
                        yield result
                    else:
                        output_list[result["index"]] = result

                    tasks[cur_index] = asyncio.create_task(
                        generators[cur_index].__anext__()
                    )
                except StopAsyncIteration:
                    del generators[cur_index]
                    del tasks[cur_index]

        if not is_stream:
            yield output_list

    def _validate_input_length(self, input_ids: List[int]):
        if len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

    def _get_sampling_params(self, sampling_params_data: dict):
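        # max_new_tokens == 0 marks prefill-only cache requests; those skip
        # normalization and verification.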
        sampling_params = SamplingParams(**sampling_params_data)
        if sampling_params.max_new_tokens != 0:
            sampling_params.normalize(self.tokenizer)
            sampling_params.verify()
        return sampling_params

    async def _wait_for_response(
        self,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
        rid: str,
        request: Optional[fastapi.Request] = None,
        index: Optional[int] = None,
        response_index: int = 0,
    ):
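        # Poll with a short timeout so a disconnected client can be detected and its
        # request aborted; otherwise keep yielding outputs until the request finishes.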
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

            if self.is_generation:
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob if index is None else obj.return_logprob[index],
                    (
                        obj.top_logprobs_num
                        if index is None
                        else obj.top_logprobs_num[index]
                    ),
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput, RewardReqInput))
                out = state.out_list[-1]

            out["index"] = response_index

            # Log requests
            if self.server_args.log_requests and state.finished:
                logger.info(f"in={obj}, out={out}")

            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _wait_for_cache_prefill_response(
        self,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
        request: Optional[fastapi.Request] = None,
    ):
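        # Only wait for the prefill-only request to finish; its output is discarded.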
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
                break
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in obj.rid:
                        self.abort_request(rid)
                    raise ValueError(f"Abort request {rid}")
                continue

        assert state.finished
        del self.rid_to_state[rid]

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
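        # Weight updates are serialized by model_update_lock: wait for in-flight requests
        # to drain, forward the update to the scheduler, and await the result that
        # handle_loop sets on model_update_result.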
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one specified in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                while len(self.rid_to_state) > 0:
                    await asyncio.sleep(0)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()
                result = await self.model_update_result
                if result.success:
                    self.server_args.model_path = obj.model_path
                    self.server_args.load_format = obj.load_format
                    self.model_path = obj.model_path
            return result.success, result.message
        else:
            return False, "Another update is in progress. Please try again later."

    def create_abort_task(self, obj: GenerateReqInput):
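        # The returned BackgroundTasks object is meant to be attached to the FastAPI
        # response so that the abort runs after the response has been sent.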
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(3)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        """The event loop that handles requests"""

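        # Dispatch detokenizer outputs (and weight-update results) to the states of the
        # requests waiting on them.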
        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
                    out_dict = {
                        "token_ids": recv_obj.decode_ids[
                            read_start : recv_obj.read_offsets[i]
                        ],
                        "meta_info": recv_obj.meta_info[i],
                    }

                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
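        # Optionally attach the decoded token text to the (logprob, token_id) pairs in
        # meta_info.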
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
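        # Returns a list of (logprob, token_id, token_text or None) tuples.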
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs