# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

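# Note: each in-flight request is tracked by a ReqState. handle_loop() appends
# incremental outputs to `out_list` and sets `event`; _wait_one_response() waits
# on `event` and yields the accumulated outputs back to the caller.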
@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For update model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

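    # generate_request() is the main entry point: it lazily starts the handle
    # loop, busy-waits while a weight update holds `model_update_lock`,
    # normalizes the request, and dispatches to the single-request or batch path.
    #
    # A minimal usage sketch (hypothetical caller, not part of this file):
    #
    #   obj = GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})
    #   async for out in tokenizer_manager.generate_request(obj, request):
    #       ...  # each `out` carries the generated text and "meta_info"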
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            async for response in self._wait_one_response(obj, request, created_time):
                yield response
        else:
            async for response in self._handle_batch_request(
                obj, request, created_time
            ):
                yield response

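    # _tokenize_one_request() resolves the prompt into token ids: it reuses
    # caller-provided input_ids/input_embeds when present, otherwise encodes
    # the text with the tokenizer; for multimodal models the image processor
    # may rewrite input_ids. It also validates the context length and sampling
    # parameters before building the scheduler request object.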
    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session[0] if obj.session else None
            session_rid = obj.session[1] if obj.session else None

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

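    # _wait_one_response() repeatedly waits on the request's event with a 4s
    # timeout; on timeout it polls request.is_disconnected() so that aborted
    # HTTP clients are cleaned up. It yields every intermediate output
    # (streaming) and exits once the request is marked finished.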
    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Wait for the response of one request."""
        event = asyncio.Event()
        state = ReqState([], False, event, created_time=created_time)
        self.rid_to_state[obj.rid] = state

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            if isinstance(obj, GenerateReqInput):
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput,))
                out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    # Log requests
                    logger.info(f"in={obj}, out={out}")
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()
            yield out

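    # _handle_batch_request() fans a batch out to the scheduler. With parallel
    # sampling it first sends one max_new_tokens=0 copy per prompt so the
    # shared prefix gets cached, then expands each prompt parallel_sample_num
    # times under fresh rids. Non-streaming callers get a single gathered list;
    # streaming callers get results interleaved as they complete, tagged with
    # their original index.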
    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(
                    self._wait_one_response(tmp_obj, request, created_time)
                )
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                await self._wait_one_response(
                    tmp_obj, request, created_time
                ).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self.send_to_scheduler.send_pyobj(tokenized_obj)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

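    # update_weights() serializes weight reloads behind `model_update_lock`,
    # waits for in-flight requests to drain, and then forwards the request to
    # the scheduler. With dp_size == 1 it awaits a single result; otherwise it
    # collects one result per data-parallel worker before reporting success.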
    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                for i in range(3):
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    # FIXME: We add some sleep here to avoid some race conditions.
                    # We can use a read-write lock as a better fix.
                    await asyncio.sleep(0.01)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()

                if self.server_args.dp_size == 1:
                    result = await self.model_update_result
                    if result.success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    return result.success, result.message
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp = []
                    result = await self.model_update_result

                    all_success = all([r.success for r in result])
                    if all_success is True:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    all_message = [r.message for r in result]
                    all_message = " | ".join(all_message)
                    return all_success, all_message

        else:
            return False, "Another update is in progress. Please try again later."

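    # Sessions: open_session() assigns a random session id, forwards it to the
    # scheduler, and waits on a future that handle_loop() resolves when the
    # scheduler acknowledges with an OpenSessionReqOutput.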
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        self.send_to_scheduler.send_pyobj(obj)
        self.session_futures[session_id] = asyncio.Future()
        session_id = await self.session_futures[session_id]
        del self.session_futures[session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

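    # create_abort_task() returns FastAPI BackgroundTasks that abort the
    # request(s) once the HTTP response is done, e.g. after a client disconnect.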
    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

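    # create_handle_loop() is called lazily from the first request; it starts
    # handle_loop() on the running event loop and installs the SIGTERM handler
    # plus the graceful-shutdown watchdog.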
    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

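    # sigterm_watchdog() waits until the SIGTERM handler flips
    # `gracefully_exit`, then lets outstanding requests drain before killing
    # the process tree.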
    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

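    # handle_loop() is the single consumer of the detokenizer socket: it
    # resolves weight-update and session futures, routes batch outputs to the
    # matching ReqState, and records per-request metrics when enabled.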
    async def handle_loop(self):
        """The event loop that handles requests."""

        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future when all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id
                )
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

                if self.enable_metrics:
                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

                    if state.first_token_time is None:
                        state.first_token_time = time.time()
                        self.metrics_collector.observe_time_to_first_token(
                            state.first_token_time - state.created_time
                        )
                    else:
                        if completion_tokens >= 2:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.first_token_time)
                                / (completion_tokens - 1)
                            )

                    if state.finished:
                        self.metrics_collector.inc_prompt_tokens(
                            recv_obj.meta_info[i]["prompt_tokens"]
                        )
                        self.metrics_collector.inc_generation_tokens(completion_tokens)
                        self.metrics_collector.observe_e2e_request_latency(
                            time.time() - state.created_time
                        )
                        if completion_tokens >= 1:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.created_time) / completion_tokens
                            )

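    # The helpers below convert (logprob, token_id) pairs from the scheduler
    # into (logprob, token_id, token_text) triples when the caller asked for
    # text in the logprobs.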
    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True