tokenizer_manager.py 69.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
18
import copy
import dataclasses
19
import json
20
import logging
21
import math
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import os
23
import pickle
24
25
import signal
import sys
26
import threading
27
import time
28
import uuid
29
from collections import deque
30
31
from datetime import datetime
from http import HTTPStatus
32
33
34
35
36
37
38
39
40
41
42
43
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
44

45
import fastapi
46
import torch
Lianmin Zheng's avatar
Lianmin Zheng committed
47
48
49
import uvloop
import zmq
import zmq.asyncio
50
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
51

52
from sglang.srt.aio_rwlock import RWLock
53
from sglang.srt.configs.model_config import ModelConfig
54
55
56
57
58
59
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
65
from sglang.srt.managers.io_struct import (
66
    AbortReq,
67
    BatchEmbeddingOut,
68
    BatchMultimodalOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
69
    BatchStrOut,
70
    BatchTokenIDOut,
71
    CloseSessionReqInput,
72
    ConfigureLoggingReq,
73
    EmbeddingReqInput,
74
    ExpertDistributionReq,
75
    ExpertDistributionReqOutput,
76
77
    FlushCacheReqInput,
    FlushCacheReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
78
    GenerateReqInput,
79
80
    GetInternalStateReq,
    GetInternalStateReqOutput,
81
82
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
83
    HealthCheckOutput,
84
85
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
86
87
88
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
89
90
    OpenSessionReqInput,
    OpenSessionReqOutput,
91
    ProfileReq,
92
93
    ProfileReqOutput,
    ProfileReqType,
94
95
96
97
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
98
    SessionParams,
99
100
    SetInternalStateReq,
    SetInternalStateReqOutput,
101
102
    SlowDownReqInput,
    SlowDownReqOutput,
103
104
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
105
106
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
107
108
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
109
110
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
111
112
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
113
)
114
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
115
116
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
117
from sglang.srt.server_args import PortArgs, ServerArgs
118
119
from sglang.srt.utils import (
    dataclass_to_string_truncated,
120
    get_bool_env_var,
121
122
123
    get_zmq_socket,
    kill_process_tree,
)
124
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
125
126
127

# Make asyncio use uvloop's event loop implementation for loops created from
# here on in this process (module import has this global side effect).
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
130

131
132
133
134
@dataclasses.dataclass
class ReqState:
    """Store the state of a request.

    An instance is created in ``TokenizerManager._send_one_request`` (keyed by
    the request id in ``rid_to_state``). ``TokenizerManager._wait_one_response``
    waits on ``event`` and then consumes ``out_list`` / ``finished``.
    """

    # Outputs received for this request that the waiter has not consumed yet;
    # the waiter keeps only the last entry and resets the list.
    out_list: List[Dict[Any, Any]]
    # When True, the waiter yields the latest output and stops.
    finished: bool
    # Awaited by the waiter; expected to be set when new output is appended
    # to ``out_list`` (the setter is outside this view).
    event: asyncio.Event
    # The original (untokenized) request object.
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    # Arrival timestamp; ``generate_request`` passes ``time.time()``.
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
166
167


168
169
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
170

Lianmin Zheng's avatar
Lianmin Zheng committed
171
172
173
174
175
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialize tokenizer state, IPC sockets and the output dispatcher.

        Args:
            server_args: Server launch configuration. Model, tokenizer,
                logging, metrics and disaggregation options are read from it.
            port_args: IPC endpoint names. Outputs are received from the
                detokenizer and requests are pushed to the scheduler.
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        # Optional JSON object of default sampling parameters. Values given
        # explicitly in a request override these (see _create_tokenized_object).
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id

        # Tokenizer / processor setup. ``self.tokenizer``, ``self.processor``
        # and ``self.mm_processor`` are defined (possibly as None) in every
        # configuration.
        if self.model_config.is_multimodal:
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )

            # The multimodal processor is created regardless of
            # skip_tokenizer_init so that multimodal inputs are still
            # processed even with skip_tokenizer_init=True.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = None
            # Text-only models have no processor. Define the attribute anyway
            # so it exists regardless of skip_tokenizer_init.
            self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.log_request_metadata = self.get_log_request_metadata()
        self.session_futures = {}  # keyed by session_id
        self.max_req_input_len = None
        self.asyncio_tasks = set()

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server, but only on the prefill tokenizer manager.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            kv_bootstrap_server_class = get_kv_class(
                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators: each is built from the scheduler socket and
        # ``dp_size``. Its ``handle_recv`` is registered below for the
        # matching scheduler output type.
        def _new_communicator():
            return _Communicator(self.send_to_scheduler, server_args.dp_size)

        self.init_weights_update_group_communicator = _new_communicator()
        self.update_weights_from_distributed_communicator = _new_communicator()
        self.update_weights_from_tensor_communicator = _new_communicator()
        self.get_weights_by_name_communicator = _new_communicator()
        self.release_memory_occupation_communicator = _new_communicator()
        self.resume_memory_occupation_communicator = _new_communicator()
        self.slow_down_communicator = _new_communicator()
        self.flush_cache_communicator = _new_communicator()
        self.profile_communicator = _new_communicator()
        self.get_internal_state_communicator = _new_communicator()
        self.set_internal_state_communicator = _new_communicator()
        self.expert_distribution_communicator = _new_communicator()
        self.update_lora_adapter_communicator = _new_communicator()

        # Route each object received from the detokenizer to its handler by type.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

412
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize a request, send it to the scheduler and yield its outputs.

        Single and batched requests are both accepted. Tokenization, dispatch
        and result streaming all happen while holding the reader side of
        ``model_update_lock``.

        Args:
            obj: A generation or embedding request.
            request: The originating HTTP request, used to detect client
                disconnects. May be None.

        Yields:
            The outputs produced for the request(s).

        Raises:
            ValueError: If an embedding request targets a generation model.
        """
        created_time = time.time()
        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        embedding_on_generation_model = (
            isinstance(obj, EmbeddingReqInput) and self.is_generation
        )
        if embedding_on_generation_model:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        if self.log_requests:
            max_length, skip_names, _unused_out_skip = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
            # Pick the response source first, then drain it in one place.
            if obj.is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                state = self._send_one_request(obj, tokenized_obj, created_time)
                responses = self._wait_one_response(obj, state, request)
            else:
                responses = self._handle_batch_request(obj, request, created_time)

            async for item in responses:
                yield item

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Turn one request into a tokenized request object.

        The input ids come from the first available source:
        ``obj.input_embeds`` (with ``obj.input_ids`` passed through alongside),
        ``obj.input_ids``, or the tokenizer applied to ``obj.text``. When a
        multimodal processor exists and the request has multimodal input, its
        result may replace the input ids. The request is then validated and
        packed.

        Raises:
            ValueError: If ``input_embeds`` is used while the radix cache is
                enabled, if a text prompt arrives without a tokenizer, or if
                validation fails.
        """
        input_text = obj.text
        input_embeds = None
        token_type_ids = None
        is_cross_encoder_request = (
            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
        )

        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            encoded = self.tokenizer(
                input_text, return_token_type_ids=is_cross_encoder_request
            )
            if is_cross_encoder_request:
                # Cross-encoder encodings are indexed with [0] (presumably a
                # batch holding one text pair).
                input_ids = encoded["input_ids"][0]
                token_type_ids = encoded.get("token_type_ids", [None])[0]
            else:
                input_ids = encoded["input_ids"]

        image_inputs: Optional[Dict] = None
        if self.mm_processor and obj.contains_mm_input():
            image_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            # The processor may re-tokenize (e.g. expand placeholders).
            if image_inputs and "input_ids" in image_inputs:
                input_ids = image_inputs["input_ids"]

        self._validate_one_request(obj, input_ids)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
        )

501
    def _validate_one_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validate a single request before it is sent to the scheduler.

        Checks that the input alone, and the input plus ``max_new_tokens``,
        stay strictly below the model's context length. For generation
        requests, also checks that the opt-in features the request uses
        (hidden-state return, custom logit processor) are enabled on the
        server.

        Args:
            obj: The request being validated.
            input_ids: Token ids of the request input. None counts as 0 tokens.

        Raises:
            ValueError: If a length limit is reached or exceeded, or a
                disabled feature is requested.
        """
        input_token_num = len(input_ids) if input_ids is not None else 0

        # Check if input alone exceeds context length
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= self.context_len
        ):
            total_tokens = max_new_tokens + input_token_num
            error_msg = (
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )
            raise ValueError(error_msg)

        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "The server is not configured to return the hidden states. "
                    "Please set `--enable-return-hidden-states` to enable this feature."
                )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                # The flag name mirrors the server arg `enable_custom_logit_processor`.
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logit-processor` to enable this feature."
                )

548
549
550
551
552
553
554
555
    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
        """Check that every token id lies in the half-open range ``[0, vocab_size)``.

        Args:
            input_ids: Token ids to check.
            vocab_size: Number of entries in the model vocabulary.

        Raises:
            ValueError: If any id is negative or not smaller than
                ``vocab_size``. The message lists at most 10 offending ids
                rather than the whole (possibly very long) input.
        """
        out_of_range = [
            token_id for token_id in input_ids if not 0 <= token_id < vocab_size
        ]
        if out_of_range:
            raise ValueError(
                f"The input_ids contains {len(out_of_range)} value(s) outside the "
                f"valid range [0, {vocab_size}) for the vocab size ({vocab_size}); "
                f"first offending ids: {out_of_range[:10]}."
            )

556
557
558
559
560
561
562
    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[List[float]] = None,
        image_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Build the tokenized request object that is sent to the scheduler.

        Sampling parameters are ``obj.sampling_params`` layered on top of
        ``self.preferred_sampling_params`` (values in the request win). They
        are then normalized with the tokenizer and verified.

        Args:
            obj: The original generation or embedding request.
            input_text: The request text as supplied by the caller.
            input_ids: Token ids of the input.
            input_embeds: Optional input embeddings (generation only).
            image_inputs: Optional multimodal processor output.
            token_type_ids: Optional token type ids (embedding only).

        Returns:
            A TokenizedGenerateReqInput for generation requests, or a
            TokenizedEmbeddingReqInput for embedding requests.

        Raises:
            TypeError: If ``obj`` is neither a GenerateReqInput nor an
                EmbeddingReqInput.
        """
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()

        # Build return object
        if isinstance(obj, GenerateReqInput):
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )
            return TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
                obj.return_logprob,
                obj.logprob_start_len,
                obj.top_logprobs_num,
                obj.token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
            )
        if isinstance(obj, EmbeddingReqInput):
            return TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                token_type_ids,
                sampling_params,
            )
        # Fail loudly instead of reaching a return with an unbound local.
        raise TypeError(f"Unsupported request type: {type(obj).__name__}")

616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Tokenize the texts of a batched request with one tokenizer call.

        Only plain-text inputs are handled. ``_handle_batch_request`` runs
        ``_validate_batch_tokenization_constraints`` before calling this.

        Args:
            batch_size: Number of sub-requests in ``obj``.
            obj: The batched request; ``obj[i]`` is the i-th sub-request.

        Returns:
            One tokenized request object per sub-request, in order.

        Raises:
            ValueError: If no tokenizer is loaded (skip_tokenizer_init=True),
                or a sub-request fails validation.
        """
        logger.debug("Starting batch tokenization for %d text requests", batch_size)

        # Same guard as the single-request path; otherwise calling a None
        # tokenizer fails with an opaque TypeError.
        if self.tokenizer is None:
            raise ValueError(
                "The engine initialized with skip_tokenizer_init=True cannot "
                "accept text prompts. Please provide input_ids or re-initialize "
                "the engine with skip_tokenizer_init=False."
            )

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        input_ids_list = self.tokenizer(texts)["input_ids"]

        # Validate and package each request with the same per-request checks
        # as the single-request path.
        tokenized_objs = []
        for req, input_ids in zip(requests, input_ids_list):
            self._validate_one_request(req, input_ids)
            tokenized_objs.append(
                self._create_tokenized_object(req, req.text, input_ids, None, None)
            )

        logger.debug("Completed batch processing for %d requests", batch_size)
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Reject batched inputs that the batch-tokenization path cannot handle.

        Each sub-request is checked, in order, for image data (generation
        models only), pre-tokenized ``input_ids`` and ``input_embeds``.

        Raises:
            ValueError: For the first sub-request violating a constraint.
        """
        # (violation predicate, error message) pairs, checked in this order.
        rules = (
            (
                lambda sub: self.is_generation and sub.image_data,
                "For image input processing do not set `enable_tokenizer_batch_encode`.",
            ),
            (
                lambda sub: sub.input_ids is not None,
                "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`.",
            ),
            (
                lambda sub: sub.input_embeds is not None,
                "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`.",
            ),
        )
        for index in range(batch_size):
            sub_req = obj[index]
            for is_violated, message in rules:
                if is_violated(sub_req):
                    raise ValueError(message)

660
661
662
663
664
665
    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        """Push a tokenized request to the scheduler and register its state.

        Returns:
            The new ReqState, also stored in ``self.rid_to_state[obj.rid]``.
        """
        # Send first; the state is only registered if sending did not raise.
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        new_state = ReqState(
            out_list=[],
            finished=False,
            event=asyncio.Event(),
            obj=obj,
            created_time=created_time,
        )
        self.rid_to_state[obj.rid] = new_state
        return new_state
670
671
672
673

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Yield the outputs of a single request as they arrive.

        Each time ``state.event`` fires, only the newest entry of
        ``state.out_list`` is kept and the list is reset. Streaming requests
        yield every such output. Non-streaming requests yield only the final
        one. While idle, the client is polled every 4 seconds for a
        disconnect.

        Raises:
            ValueError: If the client disconnected (the request is aborted
                first), or the finished output carries an ``abort`` finish
                reason with status ``HTTPStatus.BAD_REQUEST``.
        """
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                client_gone = request is not None and await request.is_disconnected()
                if not client_gone:
                    continue
                # Client went away while waiting: abort the request, then use
                # an exception to unwind the whole call stack and asyncio task.
                self.abort_request(obj.rid)
                raise ValueError(
                    f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                )

            latest = state.out_list[-1]
            state.out_list = []

            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(latest, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # A scheduler-side abort flagged as BAD_REQUEST becomes an error.
                finish_reason = latest["meta_info"].get("finish_reason")
                is_bad_request_abort = (
                    isinstance(finish_reason, dict)
                    and finish_reason.get("type") == "abort"
                    and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                )
                if is_bad_request_abort:
                    raise ValueError(finish_reason["message"])

                yield latest
                return

            state.event.clear()

            if obj.stream:
                yield latest
                continue

            # Non-streaming request still running: abort if the client left.
            if request is not None and await request.is_disconnected():
                self.abort_request(obj.rid)
                raise ValueError(
                    f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                )
727
728
729
730
731
732
733
734
735
736
737
738

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Tokenize, send and await every sub-request of a batched input.

        Without parallel sampling, each ``obj[i]`` is tokenized and sent. Tokenization
        is done in one batch call when ``enable_tokenizer_batch_encode`` is set, and one
        request at a time otherwise. With ``parallel_sample_num > 1``:

        1. Each prompt is first sent once with ``max_new_tokens = 0`` to cache the
           common prefix, and that reply is awaited.
        2. The prompt is then expanded into ``parallel_sample_num`` copies with fresh
           rids.

        Non-streaming: yields one list holding the first output of every sub-request,
        in submission order. Streaming: yields outputs as they arrive, each tagged
        with ``"index"``, its position in submission order.
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    sub_obj = obj[i]
                    state = self._send_one_request(sub_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(sub_obj, state, request))
                    rids.append(sub_obj.rid)
            else:
                # Sequential tokenization and processing
                for i in range(batch_size):
                    sub_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(sub_obj)
                    state = self._send_one_request(sub_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(sub_obj, state, request))
                    rids.append(sub_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(prompt_obj) for prompt_obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for prompt_obj, prompt_tokenized in zip(objs, tokenized_objs):
                prefix_obj = copy.copy(prompt_obj)
                prefix_tokenized = copy.copy(prompt_tokenized)
                prefix_tokenized.rid = prefix_obj.regenerate_rid()
                prefix_tokenized.sampling_params = copy.copy(
                    prefix_tokenized.sampling_params
                )
                prefix_tokenized.sampling_params.max_new_tokens = 0
                prefix_tokenized.stream = False
                state = self._send_one_request(prefix_obj, prefix_tokenized, created_time)
                await self._wait_one_response(prefix_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for prompt_obj, prompt_tokenized in zip(objs, tokenized_objs):
                for _ in range(obj.parallel_sample_num):
                    sample_obj = copy.copy(prompt_obj)
                    sample_tokenized = copy.copy(prompt_tokenized)
                    sample_tokenized.rid = sample_obj.regenerate_rid()
                    state = self._send_one_request(
                        sample_obj, sample_tokenized, created_time
                    )
                    generators.append(
                        self._wait_one_response(sample_obj, state, request)
                    )
                    rids.append(sample_obj.rid)

        # Wait for all requests
        if not getattr(obj, "stream", False):
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
            return

        rid_to_index = {rid: i for i, rid in enumerate(rids)}
        # NOTE(review): if the consumer stops iterating early (or a sub-request
        # raises), the tasks still in task_map keep running unobserved -- consider
        # cancelling them (and aborting their rids) on exit.
        task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                gen = task_map.pop(task)
                try:
                    result = task.result()
                except StopAsyncIteration:
                    # This sub-request's generator is exhausted.
                    continue
                result["index"] = rid_to_index[result["meta_info"]["id"]]
                yield result
                task_map[asyncio.create_task(gen.__anext__())] = gen
817

818
    async def flush_cache(self) -> FlushCacheReqOutput:
        """Send a ``FlushCacheReqInput`` through the flush-cache communicator.

        Returns the first reply.
        """
        replies = await self.flush_cache_communicator(FlushCacheReqInput())
        return replies[0]
Liangsheng Yin's avatar
Liangsheng Yin committed
820

821
    def abort_request(self, rid: str):
        """Send an ``AbortReq`` for ``rid`` to the scheduler if the rid is tracked.

        Rids absent from ``rid_to_state`` are ignored silently. When metrics are
        enabled, each abort that is sent is also counted.
        """
        if rid in self.rid_to_state:
            self.send_to_scheduler.send_pyobj(AbortReq(rid))
            if self.enable_metrics:
                self.metrics_collector.observe_one_aborted_request()

830
831
832
833
834
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        """Send a ``START_PROFILE`` ``ProfileReq`` and return the first reply.

        Stack recording is on unless either the ``with_stack`` argument or the
        ``SGLANG_PROFILE_WITH_STACK`` env var is explicitly false. ``None`` counts as
        "not disabled". ``profile_id`` is the current ``time.time()`` as a string.
        """
        self.auto_create_handle_loop()

        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        stack_enabled = with_stack is not False and env_with_stack is not False

        return await self._execute_profile(
            ProfileReq(
                type=ProfileReqType.START_PROFILE,
                output_dir=output_dir,
                num_steps=num_steps,
                activities=activities,
                with_stack=stack_enabled,
                record_shapes=record_shapes,
                profile_by_stage=profile_by_stage,
                profile_id=str(time.time()),
            )
        )

    async def stop_profile(self):
        """Send a ``STOP_PROFILE`` ``ProfileReq`` and return the first reply.

        ``_execute_profile`` raises ``RuntimeError`` if that reply reports failure.
        """
        self.auto_create_handle_loop()
        return await self._execute_profile(
            ProfileReq(type=ProfileReqType.STOP_PROFILE)
        )

    async def _execute_profile(self, req: ProfileReq):
        """Send ``req`` through the profile communicator and return the first reply.

        Raises:
            RuntimeError: with the reply's ``message`` when ``success`` is false.
        """
        first_reply = (await self.profile_communicator(req))[0]
        if first_reply.success:
            return first_reply
        raise RuntimeError(first_reply.message)
864

865
    async def start_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.START_RECORD`` via the expert-distribution communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.START_RECORD
        await self.expert_distribution_communicator(command)
868

869
    async def stop_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.STOP_RECORD`` via the expert-distribution communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.STOP_RECORD
        await self.expert_distribution_communicator(command)
872

873
    async def dump_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.DUMP_RECORD`` via the expert-distribution communicator."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.DUMP_RECORD
        await self.expert_distribution_communicator(command)
876

Chayenne's avatar
Chayenne committed
877
878
879
880
    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str, Any]:
        """Reload model weights from disk while holding the model-update writer lock.

        ``obj.load_format`` defaults to ``server_args.load_format`` when unset.

        Returns:
            The 3-tuple produced by ``_wait_for_model_update_from_disk``:
            ``(success, message, num_paused_requests)``. With ``dp_size > 1`` the
            last element is a list with one entry per reply. (The previous
            ``Tuple[bool, str]`` annotation did not match what is returned.)
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)
894

895
896
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str, Any]:
        """Send a disk weight-update request and wait for its result.

        The result is delivered through ``self.model_update_result``, presumably
        resolved by a reply handler elsewhere in this class. On success, the tracked
        model path, served model name and load format are updated.

        Returns:
            ``dp_size == 1``: ``(success, message, num_paused_requests)`` of the reply.
            Otherwise: ``(all_success, " | "-joined messages, [num_paused_requests
            of each reply])``.
        """
        # Create the result future (and, for DP, reset the per-rank buffer) before
        # sending, so it is in place no matter when the reply is dispatched.
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size != 1:
            self.model_update_tmp = []
        self.send_to_scheduler.send_pyobj(obj)

        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests

        # dp_size > 1: the future resolves to an iterable of per-rank results.
        results = await self.model_update_result
        all_success = all(r.success for r in results)
        if all_success:
            # Update served_model_name here as well, consistent with the
            # single-rank branch above.
            self.served_model_name = obj.model_path
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        all_message = " | ".join(r.message for r in results)
        all_paused_requests = [r.num_paused_requests for r in results]
        return all_success, all_message, all_paused_requests
921

922
923
924
925
    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` through the init-weights-update-group communicator.

        Only supported with ``dp_size == 1`` (asserted). Returns ``(success, message)``
        from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        replies = await self.init_weights_update_group_communicator(obj)
        first_reply = replies[0]
        return first_reply.success, first_reply.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` while holding the model-update writer lock.

        Requires ``dp_size == 1`` or DP attention (asserted). Returns
        ``(success, message)`` from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        # The writer lock means this weight sync cannot run while requests are
        # in progress.
        async with self.model_update_lock.writer_lock:
            replies = await self.update_weights_from_distributed_communicator(obj)
            first_reply = replies[0]
            return first_reply.success, first_reply.message
949

950
951
952
953
954
955
956
    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` while holding the model-update writer lock.

        Requires ``dp_size == 1`` or DP attention (asserted). Returns
        ``(success, message)`` from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        # The writer lock means this weight sync cannot run while requests are
        # in progress.
        async with self.model_update_lock.writer_lock:
            replies = await self.update_weights_from_tensor_communicator(obj)
            first_reply = replies[0]
            return first_reply.success, first_reply.message

966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
    async def load_lora_adapter(
        self,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        """Load a LoRA adapter via the LoRA-adapter communicator.

        Runs under the model-update writer lock, requires ``dp_size == 1`` (asserted),
        and returns the first reply.
        """
        self.auto_create_handle_loop()

        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
        # with dp_size > 1.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start load Lora adapter. Lora name=%s, path=%s",
            obj.lora_name,
            obj.lora_path,
        )

        async with self.model_update_lock.writer_lock:
            replies = await self.update_lora_adapter_communicator(obj)
            return replies[0]

    async def unload_lora_adapter(
        self,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload a LoRA adapter via the LoRA-adapter communicator.

        Runs under the model-update writer lock, requires ``dp_size == 1`` (asserted),
        and returns the first reply.
        """
        self.auto_create_handle_loop()

        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
        # with dp_size > 1.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start unload Lora adapter. Lora name=%s",
            obj.lora_name,
        )

        async with self.model_update_lock.writer_lock:
            replies = await self.update_lora_adapter_communicator(obj)
            return replies[0]

1009
1010
1011
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch parameters by name through the get-weights-by-name communicator.

        With ``dp_size == 1``, returns the first reply's ``parameter``. Otherwise,
        returns a list with the ``parameter`` of every reply.
        """
        self.auto_create_handle_loop()
        replies = await self.get_weights_by_name_communicator(obj)
        parameters = [reply.parameter for reply in replies]
        if self.server_args.dp_size != 1:
            return parameters
        return parameters[0]

1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through the release-memory-occupation communicator.

        Awaits the reply and discards it.
        """
        self.auto_create_handle_loop()
        communicator = self.release_memory_occupation_communicator
        await communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through the resume-memory-occupation communicator.

        Awaits the reply and discards it.
        """
        self.auto_create_handle_loop()
        communicator = self.resume_memory_occupation_communicator
        await communicator(obj)

1036
1037
1038
1039
1040
1041
1042
1043
    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through the slow-down communicator.

        Awaits the reply and discards it.
        """
        self.auto_create_handle_loop()
        communicator = self.slow_down_communicator
        await communicator(obj)

1044
1045
1046
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Ask the scheduler to open a session and return what the reply resolves to.

        A random ``uuid4`` hex id is assigned when ``obj.session_id`` is None. If an
        ``open_session`` call for the same id is still pending (its future is in
        ``session_futures``), returns None immediately without sending anything.
        """
        self.auto_create_handle_loop()

        if obj.session_id is not None and obj.session_id in self.session_futures:
            return None
        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex

        self.send_to_scheduler.send_pyobj(obj)

        # The future is resolved elsewhere (presumably by the reply handler), keyed
        # by session id.
        pending = asyncio.Future()
        self.session_futures[obj.session_id] = pending
        session_id = await pending
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Send ``obj`` to the scheduler and await the send."""
        # NOTE(review): unlike the other call sites, this awaits send_pyobj. That
        # only works if send_to_scheduler is an asyncio socket -- confirm.
        scheduler_socket = self.send_to_scheduler
        await scheduler_socket.send_pyobj(obj)

1066
    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        """Send a ``GetInternalStateReq`` and return each reply's ``internal_state``.

        With several DP ranks, the returned list has one entry per reply.
        """
        replies = await self.get_internal_state_communicator(GetInternalStateReq())
        return [reply.internal_state for reply in replies]
1073

Liangsheng Yin's avatar
Liangsheng Yin committed
1074
1075
1076
1077
1078
1079
1080
1081
    async def get_load(self) -> dict:
        """Return ``{"load": current_load}``, refreshing it when nobody else is.

        If ``current_load_lock`` is already held (another refresh is in flight), the
        cached value is returned without waiting. Otherwise the lock is taken and
        ``current_load`` is set from the first entry of ``get_internal_state()``.
        """
        # TODO(lsyin): fake load report server
        if self.current_load_lock.locked():
            return {"load": self.current_load}
        async with self.current_load_lock:
            states = await self.get_internal_state()
            self.current_load = states[0]["load"]
        return {"load": self.current_load}

1082
1083
1084
1085
1086
1087
1088
1089
    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> List[Any]:
        """Send ``obj`` through the set-internal-state communicator.

        Returns a list with each reply's ``internal_state``, one entry per reply. The
        return annotation is a list; the former ``SetInternalStateReqOutput``
        annotation did not describe this.
        """
        responses = await self.set_internal_state_communicator(obj)
        return [res.internal_state for res in responses]

1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
    def get_log_request_metadata(self):
        """Return ``(max_length, skip_names, out_skip_names)`` for request logging.

        All three are None when ``log_requests`` is off. Otherwise, by level:

        * 0: ``max_length`` is ``1 << 30``; the common bulky input fields plus
          ``sampling_params`` are skipped, and ``text``/``output_ids`` are skipped
          in outputs.
        * 1: same as level 0, but ``sampling_params`` is kept.
        * 2: ``max_length`` is 2048 and no field lists are used.
        * 3: ``max_length`` is ``1 << 30`` and no field lists are used.

        Raises:
            ValueError: for any other level.
        """
        if not self.log_requests:
            return None, None, None

        level = self.log_requests_level
        if level == 0 or level == 1:
            skip_names = {
                "text",
                "input_ids",
                "input_embeds",
                "image_data",
                "audio_data",
                "lora_path",
            }
            if level == 0:
                skip_names.add("sampling_params")
            return 1 << 30, skip_names, {"text", "output_ids"}
        if level == 2:
            return 2048, None, None
        if level == 3:
            return 1 << 30, None, None
        raise ValueError(
            f"Invalid --log-requests-level: {self.log_requests_level=}"
        )

1142
    def configure_logging(self, obj: ConfigureLoggingReq):
        """Apply the non-None logging/dump fields of ``obj`` to this manager.

        The fields handled are ``log_requests``, ``log_requests_level``,
        ``dump_requests_folder``, ``dump_requests_threshold`` and
        ``crash_dump_folder``. ``log_request_metadata`` is recomputed afterwards so a
        new level takes effect.
        """
        # Only fields the caller explicitly set (non-None) override current values.
        for field_name in (
            "log_requests",
            "log_requests_level",
            "dump_requests_folder",
            "dump_requests_threshold",
            "crash_dump_folder",
        ):
            value = getattr(obj, field_name)
            if value is not None:
                setattr(self, field_name, value)
        # Use the module logger instead of the root-level ``logging.info``, which
        # implicitly calls ``logging.basicConfig()`` when the root logger has no
        # handlers. The emitted text is the same as the former f"{obj=}".
        logger.info("Config logging: obj=%r", obj)
        self.log_request_metadata = self.get_log_request_metadata()
1155

Lianmin Zheng's avatar
Lianmin Zheng committed
1156
    def create_abort_task(self, obj: GenerateReqInput):
        """Return a ``BackgroundTasks`` that aborts ``obj``'s request(s) after 2 s.

        A single request aborts ``obj.rid``; a batched one aborts every rid in
        ``obj.rid``.
        """

        # Abort the request if the client is disconnected.
        # NOTE(review): no disconnect check happens here; presumably callers only
        # schedule this for disconnected clients -- confirm at the call site.
        async def abort_after_delay():
            await asyncio.sleep(2)
            rids = [obj.rid] if obj.is_single else obj.rid
            for rid in rids:
                self.abort_request(rid)

        tasks = BackgroundTasks()
        tasks.add_task(abort_after_delay)
        return tasks

1170
    def auto_create_handle_loop(self):
        """Start the background tasks on the current event loop, once.

        Later calls return immediately; the ``no_create_loop`` flag guards this.
        The first call:

        * schedules ``handle_loop`` and ``sigterm_watchdog``, each wrapped in
          ``print_exception_wrapper``, and keeps the tasks in ``asyncio_tasks``;
        * records the loop as ``self.event_loop``;
        * installs the SIGTERM/SIGQUIT handlers from ``SignalHandler``, on the main
          thread only.
        """
        if self.no_create_loop:
            return
        self.no_create_loop = True

        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )
        self.event_loop = loop

        # CPython only allows signal handlers to be added from the main thread.
        if threading.current_thread() is not threading.main_thread():
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        else:
            handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, handler.sigterm_handler)
            # Replaces the SIGQUIT handler installed during the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, handler.running_phase_sigquit_handler
            )

        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )
1200

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
    def dump_requests_before_crash(self):
        """Pickle recorded and still-unfinished requests to ``crash_dump_folder``.

        Runs at most once; ``crash_dump_performed`` guards this. It does nothing more
        when ``crash_dump_folder`` is unset or there is nothing to dump.

        The file is written to ``<crash_dump_folder>/<hostname>/crash_dump_<ts>.pkl``
        and holds ``{"server_args": ..., "requests": [...]}``. Finished requests come
        from ``crash_dump_request_list``; unfinished ones are taken from
        ``rid_to_state``.
        """
        if self.crash_dump_performed:
            logger.info(
                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
            )
            return
        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
        self.crash_dump_performed = True
        if not self.crash_dump_folder:
            return

        # Tolerate an unset list so both the dump and the final count work.
        finished_requests = list(self.crash_dump_request_list or [])

        # Add unfinished requests from rid_to_state, with an empty dict placeholder
        # in the second slot.
        unfinished_requests = [
            (state.obj, {}, state.created_time, time.time())
            for state in self.rid_to_state.values()
            if not state.finished
        ]

        data_to_dump = finished_requests + unfinished_requests
        if not data_to_dump:
            return

        # $HOSTNAME is a bash shell variable that is often not exported, so
        # os.getenv can return None. os.path.join(None) would raise TypeError and the
        # crash dump would be lost, so fall back to the socket hostname.
        import socket

        hostname = os.getenv("HOSTNAME") or socket.gethostname()
        filename = os.path.join(
            self.crash_dump_folder,
            hostname,
            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
        )

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Include server_args in the dump
        data_to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_to_dump,
        }
        with open(filename, "wb") as f:
            pickle.dump(data_to_dump_with_server_args, f)
        logger.error(
            f"Dumped {len(finished_requests)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )

1247
1248
    async def sigterm_watchdog(self):
        """Once ``gracefully_exit`` is set, drain requests and terminate the process.

        Polls every 5 s until ``gracefully_exit`` becomes true. It then exits:

        * immediately, after a crash dump, if ``health_check_failed`` is set;
        * immediately, without a dump, if the ``SGL_FORCE_SHUTDOWN`` env var is true;
        * otherwise once ``rid_to_state`` is empty (re-checked every 5 s), after a
          crash dump.

        Exiting means killing the process tree including this process, then
        ``sys.exit(0)``.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                self.dump_requests_before_crash()
                break

            if get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                # if force shutdown flag set, exit immediately
                logger.error(
                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
                continue

            self.dump_requests_before_crash()
            break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
1283

Lianmin Zheng's avatar
Lianmin Zheng committed
1284
    async def handle_loop(self):
        """The event loop that handles requests.

        Receives objects from ``recv_from_detokenizer`` forever and passes each one
        to ``_result_dispatcher``. After every dispatch it records
        ``last_receive_tstamp`` (``time.time()``).
        """
        while True:
            message = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(message)
            self.last_receive_tstamp = time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
1291

1292
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        """Distribute one batched output to the per-request ``ReqState`` objects.

        For each ``rid`` in ``recv_obj.rids`` this:

        * builds the output dict (``text``, ``output_ids`` or ``embedding``, plus
          ``meta_info``);
        * appends it to ``state.out_list`` and sets ``state.event``;
        * for finished requests, records ``finished_time``/``e2e_latency`` and removes
          them from ``rid_to_state``;
        * runs the metrics, dump and crash-dump hooks when they are enabled.

        Outputs for rids no longer tracked are logged and skipped.
        """
        is_embedding_out = isinstance(recv_obj, BatchEmbeddingOut)

        for idx, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Keys are inserted in a fixed order, which is the order seen downstream.
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[idx],
                "prompt_tokens": recv_obj.prompt_tokens[idx],
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info=meta_info,
                    state=state,
                    top_logprobs_num=state.obj.top_logprobs_num,
                    token_ids_logprob=state.obj.token_ids_logprob,
                    return_text_in_logprobs=(
                        state.obj.return_text_in_logprobs
                        and not self.server_args.skip_tokenizer_init
                    ),
                    recv_obj=recv_obj,
                    recv_obj_index=idx,
                )

            if not is_embedding_out:
                meta_info["completion_tokens"] = recv_obj.completion_tokens[idx]
                meta_info["cached_tokens"] = recv_obj.cached_tokens[idx]

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[idx]

            if isinstance(recv_obj, BatchStrOut):
                # Text accumulates, so each output carries all text so far.
                state.text += recv_obj.output_strs[idx]
                out_dict = {"text": state.text, "meta_info": meta_info}
            elif isinstance(recv_obj, BatchTokenIDOut):
                state.output_ids.extend(recv_obj.output_ids[idx])
                if self.server_args.stream_output and state.obj.stream:
                    # Incremental streaming: only ids not emitted before.
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    # Cumulative: a snapshot of every id so far.
                    output_token_ids = state.output_ids.copy()
                out_dict = {"output_ids": output_token_ids, "meta_info": meta_info}
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[idx],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[idx] is not None
            if state.finished:
                # out_dict holds a reference to meta_info, so these keys still
                # appear in it.
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[idx]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            # Hand the output to whoever waits on this request's event.
            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, idx)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)
1382
1383
1384
1385

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        """Append one request's logprob slice to ``state`` and expose it in ``meta_info``.

        Row ``recv_obj_index`` of each logprob field of ``recv_obj`` is appended to
        the matching running list on ``state``. ``meta_info`` then receives the
        whole accumulated history (not just the new slice), detokenized with text
        when ``return_text_in_logprobs`` is set. Three families are handled, each
        writing an ``input_*`` and an ``output_*`` key:

        * ``*_token_logprobs``: always;
        * ``*_top_logprobs``: only when ``top_logprobs_num > 0``;
        * ``*_token_ids_logprobs``: only when ``token_ids_logprob`` is not None.

        The input-side history is extended only when the batch carries input
        entries (non-empty ``input_*_val``). The output side is always extended.
        Returns early, touching nothing, when ``recv_obj.input_token_logprobs_val``
        is None.
        """
        src, k = recv_obj, recv_obj_index
        if src.input_token_logprobs_val is None:
            return

        as_text = return_text_in_logprobs
        flat = self.detokenize_logprob_tokens
        nested = self.detokenize_top_logprobs_tokens

        # Per-token logprobs.
        if len(src.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(src.input_token_logprobs_val[k])
            state.input_token_logprobs_idx.extend(src.input_token_logprobs_idx[k])
        state.output_token_logprobs_val.extend(src.output_token_logprobs_val[k])
        state.output_token_logprobs_idx.extend(src.output_token_logprobs_idx[k])
        meta_info["input_token_logprobs"] = flat(
            state.input_token_logprobs_val, state.input_token_logprobs_idx, as_text
        )
        meta_info["output_token_logprobs"] = flat(
            state.output_token_logprobs_val, state.output_token_logprobs_idx, as_text
        )

        # Per-position top-logprob lists, only when requested.
        if top_logprobs_num > 0:
            if len(src.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(src.input_top_logprobs_val[k])
                state.input_top_logprobs_idx.extend(src.input_top_logprobs_idx[k])
            state.output_top_logprobs_val.extend(src.output_top_logprobs_val[k])
            state.output_top_logprobs_idx.extend(src.output_top_logprobs_idx[k])
            meta_info["input_top_logprobs"] = nested(
                state.input_top_logprobs_val, state.input_top_logprobs_idx, as_text
            )
            meta_info["output_top_logprobs"] = nested(
                state.output_top_logprobs_val, state.output_top_logprobs_idx, as_text
            )

        # Per-position logprobs of the caller-chosen token ids, only when requested.
        if token_ids_logprob is not None:
            if len(src.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    src.input_token_ids_logprobs_val[k]
                )
                state.input_token_ids_logprobs_idx.extend(
                    src.input_token_ids_logprobs_idx[k]
                )
            state.output_token_ids_logprobs_val.extend(
                src.output_token_ids_logprobs_val[k]
            )
            state.output_token_ids_logprobs_idx.extend(
                src.output_token_ids_logprobs_idx[k]
            )
            meta_info["input_token_ids_logprobs"] = nested(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                as_text,
            )
            meta_info["output_token_ids_logprobs"] = nested(
                state.output_token_ids_logprobs_val,
                state.output_token_ids_logprobs_idx,
                as_text,
            )

1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Pair each logprob with its token id and, optionally, its decoded text.

        Returns a list of ``(logprob, token_id, text)`` tuples, truncated to the
        shorter input (``zip`` semantics). ``text`` is ``None`` unless
        ``decode_to_text`` is set. In that case all ids are decoded with a single
        ``self.tokenizer.batch_decode`` call, and a tokenizer must be loaded.
        """
        if decode_to_text:
            assert self.tokenizer is not None
            texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return [
                (logprob, token_id, text)
                for logprob, token_id, text in zip(
                    token_logprobs_val, token_logprobs_idx, texts
                )
            ]
        # No decoding requested: the text slot stays empty.
        return [
            (logprob, token_id, None)
            for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
        ]

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Detokenize nested, per-position logprob lists.

        Position ``p`` holds the parallel lists ``token_logprobs_val[p]`` and
        ``token_logprobs_idx[p]``; each element is used as a list here, despite
        the flat ``List[float]``/``List[int]`` annotations. Positions with a
        truthy value list are converted with ``detokenize_logprob_tokens``.
        Empty or falsy positions become ``None``.
        """
        # TODO: decoding is batched only within a single position; batching the
        # top-k tokens of every position together would be cheaper.
        return [
            (
                self.detokenize_logprob_tokens(
                    position_vals, token_logprobs_idx[pos], decode_to_text
                )
                if position_vals
                else None
            )
            for pos, position_vals in enumerate(token_logprobs_val)
        ]

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        """Report request ``i`` of ``recv_obj`` to ``self.metrics_collector``.

        * When ``state.first_token_time == 0.0`` and the server is not in prefill
          disaggregation mode: stamp ``first_token_time``/``last_time`` and
          observe time-to-first-token.
        * Otherwise: observe inter-token latency, but only when the
          completion-token count changed since the previous observation.
        * Once ``state.finished`` is set: observe one finished request with the
          prompt/completion/cached token counts, the end-to-end latency, and
          ``has_grammar``. ``has_grammar`` is the first truthy grammar constraint
          among json_schema/regex/ebnf/structural_tag; it is not coerced to bool.
        """
        counts = getattr(recv_obj, "completion_tokens", None)
        # A batch without (or with an empty) completion_tokens counts as 0.
        completion_tokens = counts[i] if counts else 0

        first_token_pending = (
            state.first_token_time == 0.0
            and self.disaggregation_mode != DisaggregationMode.PREFILL
        )
        if first_token_pending:
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                now = time.time()
                self.metrics_collector.observe_inter_token_latency(
                    now - state.last_time, num_new_tokens
                )
                state.last_time = now
                state.last_completion_tokens = completion_tokens

        if not state.finished:
            return

        params = state.obj.sampling_params
        has_grammar = (
            params.get("json_schema", None)
            or params.get("regex", None)
            or params.get("ebnf", None)
            or params.get("structural_tag", None)
        )
        self.metrics_collector.observe_one_finished_request(
            recv_obj.prompt_tokens[i],
            completion_tokens,
            recv_obj.cached_tokens[i],
            state.finished_time - state.created_time,
            has_grammar,
        )

    def dump_requests(self, state: ReqState, out_dict: dict):
        """Buffer a finished request and periodically pickle the buffer to disk.

        Each call appends ``(state.obj, out_dict, state.created_time, now)`` to
        ``self.dump_request_list``. Once the buffer holds at least
        ``self.dump_requests_threshold`` entries, it is swapped for an empty list.
        The old buffer, together with ``self.server_args``, is then written on a
        worker thread (``asyncio.to_thread``) to a ``.pkl`` file under
        ``self.dump_requests_folder``. The file is named after the current time,
        to the second.

        Reaching the threshold calls ``asyncio.create_task``, which requires a
        running event loop.
        """
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) < self.dump_requests_threshold:
            return

        filename = os.path.join(
            self.dump_requests_folder,
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
        )
        logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

        to_dump = self.dump_request_list
        self.dump_request_list = []

        to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": to_dump,
        }

        def background_task():
            os.makedirs(self.dump_requests_folder, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        # The event loop keeps only weak references to tasks, so an unreferenced
        # task can be garbage-collected before the write finishes. Keep a strong
        # reference until it completes (the set is created lazily here because
        # this method owns it).
        pending = getattr(self, "_dump_request_tasks", None)
        if pending is None:
            pending = self._dump_request_tasks = set()

        def on_dump_done(task):
            pending.discard(task)
            if task.cancelled():
                return
            # Retrieving the exception also silences asyncio's
            # "Task exception was never retrieved" warning.
            exc = task.exception()
            if exc is not None:
                logger.error(f"Failed to dump requests to {filename}: {exc!r}")

        # Schedule the write in the background without awaiting it.
        task = asyncio.create_task(asyncio.to_thread(background_task))
        pending.add(task)
        task.add_done_callback(on_dump_done)

1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        """Remember a finished request for a possible crash dump.

        Appends ``(state.obj, out_dict, state.created_time, finish_time)`` to
        ``self.crash_dump_request_list``, then evicts entries from the front
        whose finish time is 300 s (5 minutes) or more before now. Entries are
        appended with the current time, so the front holds the oldest one.
        ``popleft`` is used, so the list is presumably a ``collections.deque``.
        """
        now = time.time()
        window = self.crash_dump_request_list
        window.append((state.obj, out_dict, state.created_time, now))
        # Index 3 of each entry is its finish time.
        while window and now - window[0][3] >= 300:
            window.popleft()

Lianmin Zheng's avatar
Lianmin Zheng committed
1591
    def _handle_abort_req(self, recv_obj):
        """Drop the tracked state of an aborted request; unknown rids are ignored."""
        rid = recv_obj.rid
        if rid in self.rid_to_state:
            del self.rid_to_state[rid]
Lianmin Zheng's avatar
Lianmin Zheng committed
1593

1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
    def _handle_open_session_req_output(self, recv_obj):
        """Resolve the pending open-session future for ``recv_obj.session_id``.

        The future receives the session id on success and ``None`` on failure.
        """
        session_id = recv_obj.session_id
        outcome = session_id if recv_obj.success else None
        self.session_futures[session_id].set_result(outcome)

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        """Deliver a weight-update reply to ``self.model_update_result``.

        With ``dp_size == 1`` the reply resolves the future directly. Otherwise
        (expected ``dp_size > 1``) replies accumulate in ``self.model_update_tmp``,
        and the future is resolved with that list once ``dp_size`` replies have
        arrived.
        """
        dp_size = self.server_args.dp_size
        if dp_size == 1:
            self.model_update_result.set_result(recv_obj)
            return
        self.model_update_tmp.append(recv_obj)
        # Resolve the future once every data-parallel rank has replied.
        if len(self.model_update_tmp) == dp_size:
            self.model_update_result.set_result(self.model_update_tmp)

1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """Score each item by the next-token probabilities of ``label_token_ids``.

        See Engine.score() for more details.

        Each item is joined with ``query`` (``item + query`` when ``item_first``,
        else ``query + item``), and one token is generated with
        ``token_ids_logprob=label_token_ids``. The result has one row per item
        and one score per entry of ``label_token_ids``, in that order. A score is
        ``exp(logprob)``, or 0.0 when the label was absent. When
        ``apply_softmax`` is set, the row is instead the softmax over the label
        logprobs.

        Args:
            query: Text, or a list of token ids.
            items: For a text query, a string or a list of strings. For a
                token-id query, a non-empty list of token-id lists.
            label_token_ids: Token ids to score; required.
            apply_softmax: Normalize the label logprobs with softmax.
            item_first: Place the item before the query.
            request: Forwarded to ``generate_request``.

        Raises:
            ValueError: ``label_token_ids`` is missing; an id lies outside
                ``[0, tokenizer.vocab_size)`` (checked only when a tokenizer is
                loaded); or the query/items type combination is unsupported.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                # Negative ids are just as out-of-vocabulary as ids >= vocab_size.
                if not 0 <= token_id < vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        # Handle string or tokenized query/items
        if isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
            if item_first:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
            batch_request = GenerateReqInput(
                text=prompts,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
            if item_first:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
            batch_request = GenerateReqInput(
                input_ids=input_ids_list,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()

        # Built once so the per-token membership test below is O(1).
        label_set = set(label_token_ids)
        scores = []
        for result in results:
            # Logprobs of the label tokens at the single generated position.
            logprobs = {}
            for logprob, token_id, _ in result["meta_info"].get(
                "output_token_ids_logprobs", []
            )[0]:
                if token_id in label_set:
                    logprobs[token_id] = logprob

            # Scores in the order of label_token_ids; absent labels get -inf.
            score_list = [
                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
            ]

            if apply_softmax:
                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
            else:
                # Convert logprobs to probabilities.
                score_list = [
                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
                ]

            scores.append(score_list)

        return scores

1702

1703
1704
1705
1706
1707
1708
1709
1710
1711
async def print_exception_wrapper(func):
    """Await ``func()`` and turn any exception into a logged shutdown.

    Some asyncio functions do not print their exceptions, so this wrapper does.
    On failure it logs the traceback and, when ``func`` is a bound method of a
    ``TokenizerManager``, asks that manager to dump its requests. It then calls
    ``kill_process_tree(os.getpid(), include_parent=True)`` and exits with
    status 1.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {tb_text}")
        owner = getattr(func, "__self__", None)
        if isinstance(owner, TokenizerManager):
            owner.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


1719
class SignalHandler:
    """Signal callbacks that act on a ``TokenizerManager``."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        # Manager whose state the handlers below act on.
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Handle SIGTERM by setting ``gracefully_exit`` on the manager.

        Only the flag is set here. The draining announced in the log message
        presumably happens wherever that flag is polled.
        """
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        manager = self.tokenizer_manager
        manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Handle SIGQUIT sent by a child: dump requests, then kill the process tree.

        Calls ``kill_process_tree`` on this process's pid without
        ``include_parent``.
        """
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        manager = self.tokenizer_manager
        manager.dump_requests_before_crash()
        own_pid = os.getpid()
        kill_process_tree(own_pid)

1736
1737
1738
1739
1740

T = TypeVar("T")


class _Communicator(Generic[T]):
    """Serialize request/reply round trips over one sender.

    At most one request is in flight at any time. A call sends ``obj`` (only
    when it is truthy) and waits until ``fan_out`` replies have been delivered
    through :meth:`handle_recv`. Concurrent callers wait for their turn in FIFO
    order.
    """

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        # Non-None while a request is in flight; set once `fan_out` replies arrived.
        self._result_event: Optional[asyncio.Event] = None
        # Replies collected so far for the in-flight request.
        self._result_values: Optional[List[T]] = None
        # One event per caller waiting for its turn, in arrival order.
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            try:
                await ready_event.wait()
            except asyncio.CancelledError:
                # Without this, a cancelled waiter's event would later be
                # "woken" in place of a live caller, stalling the queue forever.
                self._abandon_turn(ready_event)
                raise
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        # NOTE(review): a cancellation during this wait skips the reset below, so
        # _result_event stays non-None and later callers queue without being
        # woken. Not handled here.
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        # Hand the turn to the next queued caller, if any.
        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def _abandon_turn(self, ready_event: asyncio.Event) -> None:
        """Give up a queued turn without stalling the callers behind it."""
        try:
            # Still waiting in line: just leave the queue.
            self._ready_queue.remove(ready_event)
        except ValueError:
            # Already handed the turn: pass it on if nothing is in flight.
            if self._result_event is None and len(self._ready_queue) > 0:
                self._ready_queue.popleft().set()

    def handle_recv(self, recv_obj: T):
        """Record one reply; wake the waiting caller once all fan-out replies arrived."""
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#