# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""
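# A minimal usage sketch (hypothetical, for orientation only; it assumes the
# scheduler and detokenizer processes are already running and connected via
# the same PortArgs):
#
#     manager = TokenizerManager(server_args, port_args)
#     async for out in manager.generate_request(
#         GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})
#     ):
#         print(out)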

import asyncio
import copy
import dataclasses
import logging
import os
import signal
import sys
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_process_tree

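# Use uvloop as the asyncio event loop implementation for lower overhead.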
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics

        # Init inter-process communication
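        # The manager pulls detokenized outputs from the detokenizer process
        # and pushes tokenized requests to the scheduler process.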
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}

        # For updating model weights
        self.model_update_lock = asyncio.Lock()
        self.model_update_result = None

        # For session info
        self.session_futures = {}  # session_id -> asyncio.Future

        # Others
        self.gracefully_exit = False

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future.
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        if self.to_create_loop:
            self.create_handle_loop()

        while self.model_update_lock.locked():
            await asyncio.sleep(0.001)

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()
        is_single = obj.is_single
        if is_single:
            tokenized_obj = await self._tokenize_one_request(obj)
            self.send_to_scheduler.send_pyobj(tokenized_obj)
            async for response in self._wait_one_response(obj, request, created_time):
                yield response
        else:
            async for response in self._handle_batch_request(
                obj, request, created_time
            ):
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is None:
            input_ids = self.tokenizer.encode(input_text)
        else:
            input_ids = obj.input_ids

        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_id = obj.session[0] if obj.session else None
            session_rid = obj.session[1] if obj.session else None

        if obj.input_ids is not None and len(input_ids) >= self.context_len:
            raise ValueError(
                f"The input ({len(input_ids)} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_id=session_id,
                session_rid=session_rid,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Wait for the response of one request."""
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state

        while True:
            try:
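                # Wake up periodically so a disconnected client can be detected
                # and its request aborted even when no new output arrives.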
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.server_args.log_requests:
                    # Log requests
                    logger.info(f"in={obj}, out={out}")
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                generators.append(
                    self._wait_one_response(tmp_obj, request, created_time)
                )
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
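            # Each prompt is sent once with max_new_tokens=0 so that its
            # prefill result lands in the radix cache; the expanded requests
            # below can then reuse the cached prefix.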
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self.send_to_scheduler.send_pyobj(tokenized_obj)
                await self._wait_one_response(
                    tmp_obj, request, created_time
                ).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self.send_to_scheduler.send_pyobj(tokenized_obj)
                    generators.append(
                        self._wait_one_response(tmp_obj, request, created_time)
                    )
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
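            # Stream sub-request outputs as they become ready, tagging each
            # chunk with the index of the request it belongs to.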
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        # Default the load format to the one in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                # wait for the previous generation requests to finish
                for i in range(3):
                    while len(self.rid_to_state) > 0:
                        await asyncio.sleep(0.001)
                    # FIXME: We add some sleep here to avoid some race conditions.
                    # We can use a read-write lock as a better fix.
                    await asyncio.sleep(0.01)
                self.send_to_scheduler.send_pyobj(obj)
                self.model_update_result = asyncio.Future()

                if self.server_args.dp_size == 1:
                    result = await self.model_update_result
                    if result.success:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    return result.success, result.message
                else:  # self.server_args.dp_size > 1
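                    # handle_loop appends one reply per data-parallel replica
                    # to model_update_tmp and resolves the future once all of
                    # them have arrived.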
                    self.model_update_tmp = []
                    result = await self.model_update_result

                    all_success = all([r.success for r in result])
                    if all_success is True:
                        self.server_args.model_path = obj.model_path
                        self.server_args.load_format = obj.load_format
                        self.model_path = obj.model_path
                    all_message = [r.message for r in result]
                    all_message = " | ".join(all_message)
                    return all_success, all_message

        else:
            return False, "Another update is in progress. Please try again later."

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        if self.to_create_loop:
            self.create_handle_loop()
        self.send_to_scheduler.send_pyobj(obj)

        self.init_weights_update_group_result = asyncio.Future()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = await self.init_weights_update_group_result
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        if not self.model_update_lock.locked():
            async with self.model_update_lock:
                self.send_to_scheduler.send_pyobj(obj)
                self.parameter_update_result = asyncio.Future()
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                result = await self.parameter_update_result
                return result.success, result.message
        else:
            logger.error("Another parameter update is in progress in tokenizer manager")
            return (
                False,
                "Another parameter update is in progress. Please try again later.",
            )

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        self.send_to_scheduler.send_pyobj(obj)
        self.get_weights_by_name_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.get_weights_by_name_result
            return result.parameter
        else:
            self.get_weights_by_name_tmp = []
            result = await self.get_weights_by_name_result
            all_parameters = [r.parameter for r in result]
            return all_parameters

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        if self.to_create_loop:
            self.create_handle_loop()

        session_id = uuid.uuid4().hex
        obj.session_id = session_id
        self.send_to_scheduler.send_pyobj(obj)
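        # The future below is resolved by handle_loop when the scheduler's
        # OpenSessionReqOutput arrives.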
        self.session_futures[session_id] = asyncio.Future()
        session_id = await self.session_futures[session_id]
        del self.session_futures[session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
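            # Wait briefly (presumably to let in-flight messages reach the
            # scheduler) before issuing the abort.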
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def create_handle_loop(self):
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

        signal_handler = SignalHandler(self)
        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        loop.create_task(self.sigterm_watchdog())

    async def sigterm_watchdog(self):
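        # Poll until the SIGTERM handler flips the flag, then drain in-flight
        # requests before exiting.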
        while not self.gracefully_exit:
            await asyncio.sleep(60)

        # drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests."""

        while True:
            recv_obj: Union[
                BatchStrOut,
                BatchEmbeddingOut,
                BatchTokenIDOut,
                UpdateWeightFromDiskReqOutput,
                UpdateWeightsFromDistributedReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
                OpenSessionReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
                for i, rid in enumerate(recv_obj.rids):
                    state = self.rid_to_state.get(rid, None)
                    if state is None:
                        continue

                    meta_info = {
                        "id": rid,
                        "finish_reason": recv_obj.finished_reasons[i],
                        "prompt_tokens": recv_obj.prompt_tokens[i],
                    }

                    if getattr(state.obj, "return_logprob", False):
                        self.convert_logprob_style(
                            meta_info,
                            state.obj.top_logprobs_num,
                            state.obj.return_text_in_logprobs,
                            recv_obj,
                            i,
                        )

                    if isinstance(recv_obj, BatchStrOut):
                        out_dict = {
                            "text": recv_obj.output_strs[i],
                            "meta_info": {
                                **meta_info,
                                "completion_tokens": recv_obj.completion_tokens[i],
                                "cached_tokens": recv_obj.cached_tokens[i],
                            },
                        }
                    elif isinstance(recv_obj, BatchTokenIDOut):
                        out_dict = {
                            "token_ids": recv_obj.output_ids[i],
                            "meta_info": {
                                **meta_info,
                                "completion_tokens": recv_obj.completion_tokens[i],
                                "cached_tokens": recv_obj.cached_tokens[i],
                            },
                        }
                    else:
                        assert isinstance(recv_obj, BatchEmbeddingOut)
                        out_dict = {
                            "embedding": recv_obj.embeddings[i],
                            "meta_info": meta_info,
                        }
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished_reasons[i] is not None
                    state.event.set()

                    if self.enable_metrics:
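                        # Record time-to-first-token on the first chunk and the
                        # mean time per output token on subsequent chunks.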
                        completion_tokens = (
                            recv_obj.completion_tokens[i]
                            if recv_obj.completion_tokens
                            else 0
                        )

                        if state.first_token_time is None:
                            state.first_token_time = time.time()
                            self.metrics_collector.observe_time_to_first_token(
                                state.first_token_time - state.created_time
                            )
                        else:
                            if completion_tokens >= 2:
                                self.metrics_collector.observe_time_per_output_token(
                                    (time.time() - state.first_token_time)
                                    / (completion_tokens - 1)
                                )

                        if state.finished:
                            self.metrics_collector.inc_prompt_tokens(
                                recv_obj.prompt_tokens[i]
                            )
                            self.metrics_collector.inc_generation_tokens(
                                completion_tokens
                            )
                            self.metrics_collector.observe_e2e_request_latency(
                                time.time() - state.created_time
                            )
                            if completion_tokens >= 1:
                                self.metrics_collector.observe_time_per_output_token(
                                    (time.time() - state.created_time)
                                    / completion_tokens
                                )
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id
                )
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future once all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
                self.init_weights_update_group_result.set_result(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.parameter_update_result.set_result(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                if self.server_args.dp_size == 1:
                    self.get_weights_by_name_result.set_result(recv_obj)
                else:
                    self.get_weights_by_name_tmp.append(recv_obj)
                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
                        self.get_weights_by_name_result.set_result(
                            self.get_weights_by_name_tmp
                        )
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.input_token_logprobs_val[recv_obj_index],
            recv_obj.input_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.output_token_logprobs_val[recv_obj_index],
            recv_obj.output_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
            recv_obj_index
        ]

        if top_logprobs_num > 0:
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_top_logprobs_val[recv_obj_index],
                recv_obj.input_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.output_top_logprobs_val[recv_obj_index],
                recv_obj.output_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
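        # Returns a list of (logprob, token_id, token_text or None) tuples.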
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True