# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightReqInput,
    UpdateWeightReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
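        # A PULL socket receives detokenized outputs back from the
        # DetokenizerManager; a PUSH socket forwards tokenized requests to the
        # Scheduler.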
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # For session info
        self.session_futures = {}  # session_id -> asyncio.Future

        # Others
        self.gracefully_exit = False

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future.
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            async for response in self._wait_one_response(obj, request, created_time):
                yield response
        else:
            async for response in self._handle_batch_request(
                obj, request, created_time
            ):
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            image_inputs = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session[0] if obj.session else None
            session_rid = obj.session[1] if obj.session else None

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Wait for the response of one request."""
        event = asyncio.Event()
        state = ReqState([], False, event, created_time=created_time)
        self.rid_to_state[obj.rid] = state

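        # Wake up on every new output chunk; the short timeout below lets us
        # notice a disconnected client and abort the request even when no
        # output has arrived yet.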
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            if isinstance(obj, GenerateReqInput):
                out = self.convert_logprob_style(
                    state.out_list[-1],
                    obj.return_logprob,
                    obj.top_logprobs_num,
                    obj.return_text_in_logprobs,
                )
            else:  # isinstance(obj, (EmbeddingReqInput,))
                out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    # Log requests
                    logger.info(f"in={obj}, out={out}")
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()
            yield out

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(
                    self._wait_one_response(tmp_obj, request, created_time)
                )
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
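            # Send a prefill-only copy of each request (max_new_tokens=0) so the
            # shared prefix is computed and cached once before the samples are
            # expanded below.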
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                await self._wait_one_response(
                    tmp_obj, request, created_time
                ).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self.send_to_scheduler.send_pyobj(tokenized_obj)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
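            # Streaming: yield outputs as they become available and tag each one
            # with the index of the request that produced it.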
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()

        self.send_to_scheduler.send_pyobj(req)
        self.mem_pool_size = asyncio.Future()

        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
        if self.server_args.dp_size == 1:
            res = await self.mem_pool_size
            return res.size
        else:  # self.server_args.dp_size > 1
            self.mem_pool_size_tmp = []
            res = await self.mem_pool_size
            ret = [r.size for r in res]
            return ret

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():

            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                for i in range(3):
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    # FIXME: We add some sleep here to avoid some race conditions.
                    # We can use a read-write lock as a better fix.
                    await asyncio.sleep(0.01)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()

                if self.server_args.dp_size == 1:
                    result = await self.model_update_result
                    if result.success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    return result.success, result.message
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp = []
                    result = await self.model_update_result

                    all_success = all([r.success for r in result])
                    if all_success is True:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    all_message = [r.message for r in result]
                    all_message = " | ".join(all_message)
                    return all_success, all_message

        else:
            return False, "Another update is in progress. Please try again later."

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        self.send_to_scheduler.send_pyobj(obj)
        self.session_futures[session_id] = asyncio.Future()
        session_id = await self.session_futures[session_id]
        del self.session_futures[session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

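        # Install a SIGTERM handler so in-flight requests can be drained before
        # the process exits (see sigterm_watchdog below).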
        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests: {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_child_process(include_self=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

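        # Receive outputs from the detokenizer, route control-plane replies
        # (weight updates, memory pool size, sessions) to their futures, and
        # fan decoded batches out to the per-request ReqState objects.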
        while True:
            recv_obj: Union[
                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, UpdateWeightReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future if all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                if self.server_args.dp_size == 1:
                    self.mem_pool_size.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.mem_pool_size_tmp.append(recv_obj)
                    # Set the future if all results are received
                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
                continue
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id
                )
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
            ), f"Unexpected obj received: {type(recv_obj)}"

            for i, rid in enumerate(recv_obj.rids):
                state = self.rid_to_state.get(rid, None)
                if state is None:
                    continue

                recv_obj.meta_info[i]["id"] = rid
                if isinstance(recv_obj, BatchStrOut):
                    out_dict = {
                        "text": recv_obj.output_strs[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                elif isinstance(recv_obj, BatchTokenIDOut):
                    out_dict = {
                        "token_ids": recv_obj.output_ids[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                else:
                    assert isinstance(recv_obj, BatchEmbeddingOut)
                    out_dict = {
                        "embedding": recv_obj.embeddings[i],
                        "meta_info": recv_obj.meta_info[i],
                    }
                state.out_list.append(out_dict)
                state.finished = recv_obj.finished_reason[i] is not None
                state.event.set()

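                # Record time-to-first-token on the first output chunk and
                # inter-token latency afterwards; prompt/generation counters and
                # end-to-end latency are recorded once the request finishes.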
                if self.enable_metrics:
                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

                    if state.first_token_time is None:
                        state.first_token_time = time.time()
                        self.metrics_collector.observe_time_to_first_token(
                            state.first_token_time - state.created_time
                        )
                    else:
                        if completion_tokens >= 2:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.first_token_time)
                                / (completion_tokens - 1)
                            )

                    if state.finished:
                        self.metrics_collector.inc_prompt_tokens(
                            recv_obj.meta_info[i]["prompt_tokens"]
                        )
                        self.metrics_collector.inc_generation_tokens(completion_tokens)
                        self.metrics_collector.observe_e2e_request_latency(
                            time.time() - state.created_time
                        )
                        if completion_tokens >= 1:
                            self.metrics_collector.observe_time_per_output_token(
                                (time.time() - state.created_time) / completion_tokens
                            )

    def convert_logprob_style(
        self,
        ret: dict,
        return_logprob: bool,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
    ):
        if return_logprob:
            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
            )
            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
            )

            if top_logprobs_num > 0:
                ret["meta_info"]["input_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["input_top_logprobs"],
                        return_text_in_logprobs,
                    )
                )
                ret["meta_info"]["output_top_logprobs"] = (
                    self.detokenize_top_logprobs_tokens(
                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                    )
                )
        return ret

    def detokenize_logprob_tokens(
        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
    ):
        # TODO(lianmin): This should run on DetokenizerManager
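        # Each entry becomes a (logprob, token_id, token_text) tuple; the text
        # field is None when decode_to_text is False.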
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

        assert self.tokenizer is not None
        token_ids = [tid for _, tid in token_logprobs]
        token_texts = self.tokenizer.batch_decode(token_ids)
        return [
            (logprob, token_id, token_text)
            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
        ]

    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        for i, token_top_logprobs in enumerate(top_logprobs):
            if token_top_logprobs:
                top_logprobs[i] = self.detokenize_logprob_tokens(
                    token_top_logprobs, decode_to_text
                )
        return top_logprobs


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True