# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.multimodal_processor import (
    get_dummy_processor,
    get_mm_processor,
    import_processors,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

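# Use uvloop as the asyncio event loop implementation for faster async I/O.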
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0
    # For incremental state update.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
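        # PULL socket receives detokenized outputs; PUSH socket sends tokenized requests to the scheduler.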
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)

        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        if self.model_config.is_multimodal:
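            # Import the multimodal processor implementations so one can be selected for this model.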
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )

            # We want to parallelize the image pre-processing so we create an executor for it
            # We create mm_processor regardless of skip_tokenizer_init so that we still encode
            # images even when skip_tokenizer_init=True.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = get_dummy_processor()

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.asyncio_tasks = set()

        # For session info
        self.session_futures = {}  # session_id -> asyncio event

        # Set after scheduler is initialized
        self.max_req_input_len = None

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators
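        # Each _Communicator sends a control request to the scheduler and collects one reply per data-parallel rank.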
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server on prefill
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Only start the bootstrap server on the prefill TokenizerManager
            kv_bootstrap_server_class = get_kv_class(
                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()

        self.auto_create_handle_loop()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        obj.normalize_batch_and_arguments()

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
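            # Holding the reader lock keeps requests from overlapping with a weight update,
            # which acquires the writer lock.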
            is_single = obj.is_single
            if is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Optional[Dict] = None
        if obj.contains_mm_input():
            image_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]

        self._validate_token_len(obj, input_ids)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, image_inputs
        )

    def _validate_token_len(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validates that the input token count and the requested token count don't exceed the model's context length."""

        input_token_num = len(input_ids) if input_ids is not None else 0
        # Check if input alone exceeds context length
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= self.context_len
        ):
            total_tokens = max_new_tokens + input_token_num
            error_msg = (
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )
            raise ValueError(error_msg)

    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[Union[List[float], None]] = None,
        image_inputs: Optional[Dict] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters."""

        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num
            token_ids_logprob = obj.token_ids_logprob
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logit-processor` to enable this feature."
                )

        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                return_logprob,
                logprob_start_len,
                top_logprobs_num,
                token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )

        return tokenized_obj

    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only."""
        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
        input_ids_list = encoded["input_ids"]

        # Process all requests
        tokenized_objs = []
        for i, req in enumerate(requests):
            self._validate_token_len(obj[i], input_ids_list[i])
            tokenized_objs.append(
                self._create_tokenized_object(
                    req, req.text, input_ids_list[i], None, None
                )
            )
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
            if self.is_generation and obj[i].image_data:
                raise ValueError(
                    "For image input processing do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
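        # Forward the tokenized request to the scheduler and register a waiter state keyed by rid.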
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        state = self.rid_to_state[obj.rid]

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
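            # No parallel sampling: tokenize and dispatch each request in the batch directly.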
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
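            # Streaming: yield chunks from whichever sub-request produces output first.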
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    async def flush_cache(self) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
            return
        req = AbortReq(rid)
        self.send_to_scheduler.send_pyobj(req)

    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
    ):
        self.auto_create_handle_loop()
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self):
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self, req: ProfileReq):
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    async def start_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        # Hold the lock if it is not async. This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
897
        ), "dp_size must be 1 for update weights from distributed"
898

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[obj.session_id] = asyncio.Future()
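        # Resolved by the handler for the scheduler's OpenSessionReqOutput reply.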
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        await self.send_to_scheduler.send_pyobj(obj)

    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> SetInternalStateReqOutput:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.internal_state for res in responses]

    def get_log_request_metadata(self):
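        # Map the configured request-logging level to a truncation length and the fields to skip.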
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                    ]
                )
            elif self.log_requests_level == 1:
                max_length = 2048
            elif self.log_requests_level == 2:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    async def sigterm_watchdog(self):
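        # Wait until the SIGTERM handler flags a graceful exit, then drain in-flight requests before exiting.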
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
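        # Scheduler outputs are batched; each field is a list indexed per request.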
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                state.text += recv_obj.output_strs[i]
                out_dict = {
                    "text": state.text,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                if isinstance(recv_obj.outputs[i], str):
                    out_dict = {
                        "text": recv_obj.outputs[i],
                        "meta_info": meta_info,
                    }
                else:
                    out_dict = {
                        "outputs": json.dumps(recv_obj.outputs[i]),
                        "meta_info": meta_info,
                    }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
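        # Accumulate incremental logprobs into the request state and expose detokenized views in meta_info.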
        if len(recv_obj.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                recv_obj.input_token_logprobs_val[recv_obj_index]
            )
            state.input_token_logprobs_idx.extend(
                recv_obj.input_token_logprobs_idx[recv_obj_index]
            )
        state.output_token_logprobs_val.extend(
            recv_obj.output_token_logprobs_val[recv_obj_index]
        )
        state.output_token_logprobs_idx.extend(
            recv_obj.output_token_logprobs_idx[recv_obj_index]
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val,
            state.input_token_logprobs_idx,
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    recv_obj.input_top_logprobs_val[recv_obj_index]
                )
                state.input_top_logprobs_idx.extend(
                    recv_obj.input_top_logprobs_idx[recv_obj_index]
                )
            state.output_top_logprobs_val.extend(
                recv_obj.output_top_logprobs_val[recv_obj_index]
            )
            state.output_top_logprobs_idx.extend(
                recv_obj.output_top_logprobs_idx[recv_obj_index]
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val,
                state.input_top_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
                )
                state.input_token_ids_logprobs_idx.extend(
                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
                )
            state.output_token_ids_logprobs_val.extend(
                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
            )
            state.output_token_ids_logprobs_idx.extend(
                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    state.output_token_ids_logprobs_val,
                    state.output_token_ids_logprobs_idx,
                    return_text_in_logprobs,
                )
            )
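
    # After convert_logprob_style, meta_info may contain (depending on the request):
    #   "input_token_logprobs" / "output_token_logprobs":
    #       flat lists of (logprob, token_id, text-or-None) tuples
    #   "input_top_logprobs" / "output_top_logprobs":
    #       per-position lists of such tuples (None for positions without top-k)
    #   "input_token_ids_logprobs" / "output_token_ids_logprobs":
    #       the same layout, restricted to the ids in token_ids_logprob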

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
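
    # Example return value of detokenize_logprob_tokens (logprobs and token ids
    # below are illustrative, not real model output):
    #   decode_to_text=False -> [(-0.11, 362, None), (-2.58, 1299, None)]
    #   decode_to_text=True  -> [(-0.11, 362, " the"), (-2.58, 1299, " a")]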

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        if state.first_token_time == 0.0:
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            has_grammar = (
                state.obj.sampling_params.get("json_schema", None)
                or state.obj.sampling_params.get("regex", None)
                or state.obj.sampling_params.get("ebnf", None)
                or state.obj.sampling_params.get("structural_tag", None)
            )
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
                has_grammar,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))
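
    # Each dumped entry is a (request_obj, out_dict, created_time, finished_time)
    # tuple. A minimal offline reader (illustrative path; not part of this module):
    #
    #   import pickle
    #   with open("dump_folder/2024-01-01_00-00-00.pkl", "rb") as f:
    #       for obj, out, created, finished in pickle.load(f):
    #           print(finished - created, out.get("meta_info", {}))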

    def _handle_abort_req(self, recv_obj):
        self.rid_to_state.pop(recv_obj.rid)

    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # set the future only after results from all data-parallel workers are received
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio task does not print its exception when it fails.
    Wrap the coroutine so the exception is logged and the process is shut down.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
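
# Typical wiring (a sketch; the real task creation happens elsewhere, and the
# coroutine name below is only a placeholder):
#
#   loop = asyncio.get_event_loop()
#   loop.create_task(print_exception_wrapper(some_background_loop_coroutine))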


class SignalHandler:
    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        logger.error(
            "Received SIGQUIT from a child process. It usually means the child failed."
        )
        kill_process_tree(os.getpid())
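
# A sketch of how these handlers might be registered (the actual wiring is done
# by the server entrypoint and may differ):
#
#   handler = SignalHandler(tokenizer_manager)
#   signal.signal(signal.SIGTERM, handler.sigterm_handler)
#   signal.signal(signal.SIGQUIT, handler.running_phase_sigquit_handler)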


T = TypeVar("T")


class _Communicator(Generic[T]):
    """Note: The communicator now only run up to 1 in-flight request at any time."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Future] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#