# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from datetime import datetime
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.image_processor import (
    get_dummy_image_processor,
    get_image_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)
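
# Rough usage sketch (illustrative, not from the original file): construct the
# manager inside a running asyncio event loop and consume generate_request as an
# async generator, e.g.
#
#     manager = TokenizerManager(server_args, port_args)
#     async for out in manager.generate_request(GenerateReqInput(text="Hello")):
#         print(out)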


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )

        self.is_generation = self.model_config.is_generation
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Create image processor placeholder
        self.image_processor = get_dummy_image_processor()

        # Create tokenizer
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )
                self.tokenizer = self.processor.tokenizer
                os.environ["TOKENIZERS_PARALLELISM"] = "false"

                # We want to parallelize the image pre-processing so we create an executor for it
                self.image_processor = get_image_processor(
                    self.model_config.hf_config, server_args, self.processor
                )
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                )

        # Store states
        self.to_create_loop = True
        self.rid_to_state: Dict[str, ReqState] = {}
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Others
        self.gracefully_exit = False
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
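        """Tokenize a single or batched request, send it to the scheduler, and stream the outputs back."""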
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")

        async with self.model_update_lock.reader_lock:
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        if self.is_generation:
            # TODO: also support getting embeddings for multimodal models
            image_inputs: Dict = await self.image_processor.process_images_async(
                obj.image_data, input_text or input_ids, obj
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

        input_token_num = len(input_ids) if input_ids is not None else 0
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        if (
            obj.sampling_params.get("max_new_tokens") is not None
            and obj.sampling_params.get("max_new_tokens") + input_token_num
            >= self.context_len
        ):
            raise ValueError(
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of "
                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
                f"completion. Please reduce the number of tokens in the input "
                f"messages or the completion to fit within the limit."
            )

        # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                obj.stream,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                sampling_params,
            )

        return tokenized_obj

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        event = asyncio.Event()
        state = ReqState([], False, event, obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        self.send_to_scheduler.send_pyobj(tokenized_obj)

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
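            # The short timeout lets us periodically check whether the client has
            # disconnected, so the request can be aborted instead of waiting forever.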
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
                    logger.info(msg)
                del self.rid_to_state[obj.rid]
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    self.abort_request(obj.rid)
                    raise ValueError(f"Abort request {obj.rid}")

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            # Send all requests
            for i in range(batch_size):
                tmp_obj = obj[i]
                tokenized_obj = await self._tokenize_one_request(tmp_obj)
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                generators.append(self._wait_one_response(tmp_obj, request))
                rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
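            # Stream outputs as they arrive from any per-request generator and tag
            # each chunk with its original position in the batch.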
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def flush_cache(self):
        req = FlushCacheReq()
        self.send_to_scheduler.send_pyobj(req)

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        del self.rid_to_state[rid]
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    def start_profile(self):
        req = ProfileReq.START_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    def stop_profile(self):
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # Default the load format to the one specified in server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message
        else:  # self.server_args.dp_size > 1
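            # With data parallelism, handle_loop collects one reply per DP worker and
            # resolves the future with the list of results aggregated below.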
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            return all_success, all_message

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from distributed"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)
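
        # The scheduler replies through handle_loop, which resolves the future
        # below with the new session id on success, or None on failure.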

        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logger.info(f"Config logging: {obj=}")

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(1)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
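        """Lazily create the request-handling and watchdog tasks on the current event loop (first call only)."""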
        if not self.to_create_loop:
            return

        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj: Union[
                BatchStrOut,
                BatchEmbeddingOut,
                BatchTokenIDOut,
                UpdateWeightFromDiskReqOutput,
                UpdateWeightsFromDistributedReqOutput,
                GetWeightsByNameReqOutput,
                InitWeightsUpdateGroupReqOutput,
                ReleaseMemoryOccupationReqOutput,
                ResumeMemoryOccupationReqOutput,
            ] = await self.recv_from_detokenizer.recv_pyobj()

            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
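                # Fan the batched outputs back out to each request's ReqState.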
                for i, rid in enumerate(recv_obj.rids):
                    state = self.rid_to_state.get(rid, None)
                    if state is None:
                        continue

                    meta_info = {
                        "id": rid,
                        "finish_reason": recv_obj.finished_reasons[i],
                        "prompt_tokens": recv_obj.prompt_tokens[i],
                    }

                    if getattr(state.obj, "return_logprob", False):
                        self.convert_logprob_style(
                            meta_info,
                            state.obj.top_logprobs_num,
                            state.obj.return_text_in_logprobs,
                            recv_obj,
                            i,
                        )

                    if not isinstance(recv_obj, BatchEmbeddingOut):
                        meta_info.update(
                            {
                                "completion_tokens": recv_obj.completion_tokens[i],
                                "cached_tokens": recv_obj.cached_tokens[i],
                            }
                        )

                    if isinstance(recv_obj, BatchStrOut):
                        out_dict = {
                            "text": recv_obj.output_strs[i],
                            "meta_info": meta_info,
                        }
                    elif isinstance(recv_obj, BatchTokenIDOut):
                        out_dict = {
                            "token_ids": recv_obj.output_ids[i],
                            "meta_info": meta_info,
                        }
                    else:
                        assert isinstance(recv_obj, BatchEmbeddingOut)
                        out_dict = {
                            "embedding": recv_obj.embeddings[i],
                            "meta_info": meta_info,
                        }
                    state.out_list.append(out_dict)
                    state.finished = recv_obj.finished_reasons[i] is not None
                    state.event.set()

                    if self.enable_metrics and state.obj.log_metrics:
                        self.collect_metrics(state, recv_obj, i)
                    if (
                        self.dump_requests_folder
                        and state.finished
                        and state.obj.log_metrics
                    ):
                        self.dump_requests(state, out_dict)
            elif isinstance(recv_obj, OpenSessionReqOutput):
                self.session_futures[recv_obj.session_id].set_result(
                    recv_obj.session_id if recv_obj.success else None
                )
            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                if self.server_args.dp_size == 1:
                    self.model_update_result.set_result(recv_obj)
                else:  # self.server_args.dp_size > 1
                    self.model_update_tmp.append(recv_obj)
                    # Set the future once all results are received
                    if len(self.model_update_tmp) == self.server_args.dp_size:
                        self.model_update_result.set_result(self.model_update_tmp)
            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
                self.init_weights_update_group_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from tensor"
                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                self.get_weights_by_name_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
                self.release_memory_occupation_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")

    def convert_logprob_style(
        self,
        meta_info: dict,
        top_logprobs_num: int,
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.input_token_logprobs_val[recv_obj_index],
            recv_obj.input_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            recv_obj.output_token_logprobs_val[recv_obj_index],
            recv_obj.output_token_logprobs_idx[recv_obj_index],
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.input_top_logprobs_val[recv_obj_index],
                recv_obj.input_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                recv_obj.output_top_logprobs_val[recv_obj_index],
                recv_obj.output_top_logprobs_idx[recv_obj_index],
                return_text_in_logprobs,
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
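        """Pair each (logprob, token_id) and optionally attach the decoded token text."""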
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
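        """Record TTFT, inter-token latency, and end-to-end latency for one output chunk of a request."""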
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        if state.first_token_time is None:
            state.first_token_time = time.time()
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            if completion_tokens >= 2:
                # Compute time_per_output_token for the streaming case
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.first_token_time) / (completion_tokens - 1)
                )

        if state.finished:
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i], completion_tokens
            )
            self.metrics_collector.observe_e2e_request_latency(
                time.time() - state.created_time
            )
            # Compute time_per_output_token for the non-streaming case
            if (
                hasattr(state.obj, "stream")
                and not state.obj.stream
                and completion_tokens >= 1
            ):
                self.metrics_collector.observe_time_per_output_token(
                    (time.time() - state.created_time) / completion_tokens
                )

    def dump_requests(self, state: ReqState, out_dict: dict):
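        """Buffer finished requests and pickle them to `dump_requests_folder` once the threshold is reached."""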
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print exception.
    We do another wrapper to handle the exception.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


class SignalHandler:
    def __init__(self, tokenizer_manager):
        self.tokenizer_manager = tokenizer_manager

    def signal_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True


T = TypeVar("T")


class _Communicator(Generic[T]):
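    """Send one request to the scheduler and wait for `fan_out` responses.

    `handle_recv` is called once per data-parallel worker; the call returns the
    collected responses as a list once all of them have arrived.
    """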
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj):
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        result_values = self._result_values
        self._result_future = self._result_values = None
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)