tokenizer_manager.py 80.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
18
import copy
import dataclasses
19
import json
20
import logging
21
import math
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import os
23
import pickle
24
25
import signal
import sys
26
import threading
27
import time
28
import uuid
29
from collections import deque
fzyzcjy's avatar
fzyzcjy committed
30
from contextlib import nullcontext
31
from datetime import datetime
32
from enum import Enum
33
from http import HTTPStatus
34
35
36
37
38
39
40
41
42
43
44
45
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
46

47
import fastapi
48
import torch
Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
51
import uvloop
import zmq
import zmq.asyncio
52
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
53

54
from sglang.srt.aio_rwlock import RWLock
55
from sglang.srt.configs.model_config import ModelConfig
56
57
58
59
60
61
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
xm:D's avatar
xm:D committed
62
63
64
65
66
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
67
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
Lianmin Zheng's avatar
Lianmin Zheng committed
68
from sglang.srt.managers.io_struct import (
69
    AbortReq,
70
    BatchEmbeddingOut,
71
    BatchMultimodalOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
72
    BatchStrOut,
73
    BatchTokenIDOut,
74
    CloseSessionReqInput,
75
    ConfigureLoggingReq,
76
    EmbeddingReqInput,
77
    ExpertDistributionReq,
78
    ExpertDistributionReqOutput,
79
80
    FlushCacheReqInput,
    FlushCacheReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
81
    GenerateReqInput,
82
83
    GetInternalStateReq,
    GetInternalStateReqOutput,
84
85
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
86
    HealthCheckOutput,
87
88
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
89
90
91
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
92
93
    OpenSessionReqInput,
    OpenSessionReqOutput,
94
    ProfileReq,
95
96
    ProfileReqOutput,
    ProfileReqType,
97
98
99
100
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
101
    SessionParams,
102
103
    SetInternalStateReq,
    SetInternalStateReqOutput,
104
105
    SlowDownReqInput,
    SlowDownReqOutput,
106
107
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
108
109
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
110
111
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
112
113
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
114
115
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
116
)
Mick's avatar
Mick committed
117
from sglang.srt.managers.mm_utils import TensorTransportMode
118
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
119
from sglang.srt.managers.scheduler import is_health_check_generate_req
fzyzcjy's avatar
fzyzcjy committed
120
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
121
122
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
123
from sglang.srt.server_args import PortArgs, ServerArgs
124
125
from sglang.srt.utils import (
    dataclass_to_string_truncated,
126
    get_bool_env_var,
127
128
129
    get_zmq_socket,
    kill_process_tree,
)
130
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
131
132
133

# Install uvloop's event-loop policy process-wide so asyncio event loops
# created through the policy from here on are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
136

137
138
139
140
@dataclasses.dataclass
class ReqState:
    """Store the state of a request.

    The first fields (``out_list``, ``finished``, ``event``, ``obj`` and
    ``created_time``) have no defaults and may be passed positionally, so their
    order is part of the constructor interface and must not change.
    """

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    # Mutable defaults use `default_factory` so each instance gets its own list.
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
172
173


174
175
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
176

Lianmin Zheng's avatar
Lianmin Zheng committed
177
178
179
180
181
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Set up tokenization, IPC sockets, request bookkeeping and dispatch.

        Args:
            server_args: Server configuration. Model, tokenizer, logging, LoRA,
                metrics and PD-disaggregation settings are read from it.
            port_args: IPC endpoint names. ``tokenizer_ipc_name`` is opened as
                a PULL socket (``recv_from_detokenizer``) and
                ``scheduler_input_ipc_name`` as a PUSH socket
                (``send_to_scheduler``).

        Raises:
            RuntimeError: In prefill disaggregation mode on ``node_rank == 0``
                with the ``"ascend"`` transfer backend, if ``mf_adapter`` cannot
                be imported or its config store cannot be created.
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self.max_req_input_len = None  # Will be set later in engine.py

        if self.model_config.is_multimodal:
            import_processors()
            try:
                _processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
            except ValueError as e:
                error_message = str(e)
                if "does not have a slow version" in error_message:
                    # Only a fast processor exists: retry with use_fast=True.
                    logger.info(
                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
                    )
                    _processor = get_processor(
                        server_args.tokenizer_path,
                        tokenizer_mode=server_args.tokenizer_mode,
                        trust_remote_code=server_args.trust_remote_code,
                        revision=server_args.revision,
                        use_fast=True,
                    )
                else:
                    raise
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # The multimodal processor is created regardless of
            # skip_tokenizer_init, so multimodal inputs are still preprocessed
            # even when skip_tokenizer_init=True leaves no tokenizer.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor, transport_mode
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Init inter-process communication (a zmq context with 2 I/O threads)
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Request states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.asyncio_tasks = set()

        # Health check
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.server_status = ServerStatus.Starting

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Session
        self.session_futures = {}  # session_id -> asyncio event

        # Weight updates
        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

        # LoRA
        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
        # Lock to serialize LoRA update operations.
        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
        # LoRA updates and inference to overlap.
        self.lora_update_lock = asyncio.Lock()

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server only on the prefill tokenizer manager.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            kv_bootstrap_server_class = get_kv_class(
                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )
            is_create_store = (
                self.server_args.node_rank == 0
                and self.server_args.disaggregation_transfer_backend == "ascend"
            )
            if is_create_store:
                try:
                    from mf_adapter import create_config_store

                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                    create_config_store(ascend_url)
                except Exception as e:
                    error_message = "Failed create mf store, invalid ascend_url."
                    error_message += f" With exception {e}"
                    # Raising a plain str is itself a TypeError in Python 3;
                    # raise a real exception and keep the original cause.
                    raise RuntimeError(error_message) from e

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators: every one sends through the scheduler socket and is
        # configured with the same dp_size.
        def make_communicator():
            return _Communicator(self.send_to_scheduler, server_args.dp_size)

        self.init_weights_update_group_communicator = make_communicator()
        self.update_weights_from_distributed_communicator = make_communicator()
        self.update_weights_from_tensor_communicator = make_communicator()
        self.get_weights_by_name_communicator = make_communicator()
        self.release_memory_occupation_communicator = make_communicator()
        self.resume_memory_occupation_communicator = make_communicator()
        self.slow_down_communicator = make_communicator()
        self.flush_cache_communicator = make_communicator()
        self.profile_communicator = make_communicator()
        self.get_internal_state_communicator = make_communicator()
        self.set_internal_state_communicator = make_communicator()
        self.expert_distribution_communicator = make_communicator()
        self.update_lora_adapter_communicator = make_communicator()

        # Route each message type received from the detokenizer/scheduler to
        # its handler.
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
                # Health-check outputs are received and intentionally ignored.
                (HealthCheckOutput, lambda x: None),
            ]
        )

470
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Async generator serving one (possibly batched) request end to end.

        The request is normalized and then waits until ``_is_updating`` is
        cleared. While holding the reader side of ``model_update_lock``, it
        either tokenizes and sends a single request and streams
        ``_wait_one_response``, or delegates to ``_handle_batch_request``.
        Every item those produce is re-yielded unchanged.
        """
        created_time = time.time()
        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        # Park here until no update is flagged in progress.
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
            if not obj.is_single:
                response_stream = self._handle_batch_request(
                    obj, request, created_time
                )
            else:
                tokenized_obj = await self._tokenize_one_request(obj)
                req_state = self._send_one_request(obj, tokenized_obj, created_time)
                response_stream = self._wait_one_response(obj, req_state, request)

            async for response in response_stream:
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request and build the object sent to the scheduler.

        Token ids are taken from, in order of precedence:
          * ``obj.input_embeds``, with ``obj.input_ids`` passed through;
          * ``obj.input_ids``;
          * the tokenizer applied to ``obj.text``.

        Multimodal data is then processed and may replace the token ids. The
        request is validated before a LoRA adapter reference is acquired.

        Returns:
            The tokenized request built by ``_create_tokenized_object``.

        Raises:
            ValueError: If ``input_embeds`` is given while the radix cache is
                enabled, or a text prompt is given without a tokenizer, or
                ``_validate_one_request`` rejects the request.
        """
        # Tokenize
        input_embeds = None
        input_text = obj.text
        token_type_ids = None
        is_cross_encoder_request = (
            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
        )

        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            encoded = self.tokenizer(
                input_text, return_token_type_ids=is_cross_encoder_request
            )

            input_ids = encoded["input_ids"]
            if is_cross_encoder_request:
                # The cross-encoder output is unwrapped with [0] (presumably a
                # batch of one text pair - confirm against the tokenizer call).
                input_ids = encoded["input_ids"][0]
                token_type_ids = encoded.get("token_type_ids", [None])[0]

        mm_inputs: Optional[Dict] = None
        if self.mm_processor and obj.contains_mm_input():
            # The processor is given list-valued image/audio data.
            if not isinstance(obj.image_data, list):
                obj.image_data = [obj.image_data]
            if not isinstance(obj.audio_data, list):
                obj.audio_data = [obj.audio_data]
            mm_inputs = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                audio_data=obj.audio_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            if mm_inputs and "input_ids" in mm_inputs:
                input_ids = mm_inputs["input_ids"]

        # Validate before acquiring a LoRA adapter reference. A rejected request
        # is never sent to the scheduler, and this method has no matching
        # release on the error path, so acquiring first would leave a dangling
        # reference behind.
        self._validate_one_request(obj, input_ids)

        if self.server_args.enable_lora and obj.lora_path:
            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )

566
    def _validate_one_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Reject a request this server cannot serve by raising ``ValueError``.

        The checks run in this order and stop at the first failure:
          1. the prompt length must be strictly below ``self.context_len``;
          2. an ``EmbeddingReqInput`` must not be sent to a generation model;
          3. if ``max_new_tokens`` is set, prompt length plus ``max_new_tokens``
             must be strictly below ``self.context_len``;
          4. a ``GenerateReqInput`` may request hidden states or a custom logit
             processor only when the server enables that feature.

        ``input_ids`` may be ``None``; it then counts as zero prompt tokens.
        """
        input_token_num = 0 if input_ids is None else len(input_ids)

        # 1. The prompt on its own.
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # 2. Request kind versus model kind.
        if self.is_generation and isinstance(obj, EmbeddingReqInput):
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        # 3. Prompt plus requested completion budget.
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if max_new_tokens is not None:
            total_tokens = max_new_tokens + input_token_num
            if total_tokens >= self.context_len:
                raise ValueError(
                    f"Requested token count exceeds the model's maximum context length "
                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                    f"tokens: {input_token_num} tokens from the input messages and "
                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
                    f"of tokens in the input messages or the completion to fit within the limit."
                )

        # 4. Feature flags that only apply to generation requests.
        if not isinstance(obj, GenerateReqInput):
            return

        hidden_states_disabled = not self.server_args.enable_return_hidden_states
        if obj.return_hidden_states and hidden_states_disabled:
            raise ValueError(
                "The server is not configured to return the hidden states. "
                "Please set `--enable-return-hidden-states` to enable this feature."
            )

        logit_processor_disabled = not self.server_args.enable_custom_logit_processor
        if obj.custom_logit_processor and logit_processor_disabled:
            raise ValueError(
                "The server is not configured to enable custom logit processor. "
                "Please set `--enable-custom-logits-processor` to enable this feature."
            )

619
620
621
622
623
624
625
626
    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
        if any(id >= vocab_size for id in input_ids):
            raise ValueError(
                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
            )

627
628
629
630
631
632
    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[List[float]] = None,
        mm_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters.

        Sampling parameters are built from ``obj.sampling_params``. If
        ``self.preferred_sampling_params`` is set, it is overlaid underneath, so
        explicitly passed keys win. The parameters are then normalized with the
        tokenizer and verified against the model's vocab size.

        Returns:
            A ``TokenizedGenerateReqInput`` for a ``GenerateReqInput``, or a
            ``TokenizedEmbeddingReqInput`` for an ``EmbeddingReqInput``.

        Raises:
            TypeError: If ``obj`` is neither of the two supported request types.
        """
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify(self.model_config.vocab_size)

        # Build return object
        if isinstance(obj, GenerateReqInput):
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )
            return TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                sampling_params,
                obj.return_logprob,
                obj.logprob_start_len,
                obj.top_logprobs_num,
                obj.token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_id=obj.lora_id,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
            )

        if isinstance(obj, EmbeddingReqInput):
            return TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                token_type_ids,
                sampling_params,
            )

        # Previously an unsupported type fell through to an UnboundLocalError.
        raise TypeError(
            f"Unsupported request type: {type(obj).__name__}. "
            "Expected GenerateReqInput or EmbeddingReqInput."
        )

687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only.

        The texts of all ``batch_size`` sub-requests of ``obj`` are tokenized in
        one tokenizer call. Each sub-request is then validated with the same
        checks as the single-request path (``_validate_one_request``) and
        converted into a tokenized object.

        Returns:
            One tokenized object per sub-request, in order. The list is empty
            when ``batch_size`` is 0.

        Raises:
            ValueError: If ``_validate_one_request`` rejects a sub-request.
        """
        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        if not requests:
            return []
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
        input_ids_list = encoded["input_ids"]

        # Process all requests
        tokenized_objs = []
        for req, input_ids in zip(requests, input_ids_list):
            # Use the same validator as `_tokenize_one_request`, so batched
            # requests also get the embedding-model, hidden-states and
            # custom-logit-processor checks.
            self._validate_one_request(req, input_ids)
            tokenized_objs.append(
                self._create_tokenized_object(req, req.text, input_ids, None, None)
            )
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing.

        Batch tokenization handles plain text only. Every sub-request must
        therefore carry no multimodal input (checked for generation models
        only), no pre-tokenized ``input_ids`` and no ``input_embeds``.

        Raises:
            ValueError: On the first sub-request that violates a constraint.
        """
        for i in range(batch_size):
            sub_req = obj[i]
            if self.is_generation and sub_req.contains_mm_input():
                raise ValueError(
                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                )
            if sub_req.input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if sub_req.input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

731
732
733
734
735
736
    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        """Forward one tokenized request to the scheduler and start tracking it.

        Args:
            obj: The original request; ``obj.rid`` keys the tracking state.
            tokenized_obj: The payload sent to the scheduler.
            created_time: Optional creation timestamp stored on the state.

        Returns:
            The new ``ReqState``, also stored in ``self.rid_to_state[obj.rid]``.
        """
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        return state
741
742
743
744

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request.

        Async generator. For streaming requests (``obj.stream``) it yields the
        latest output each time ``state.event`` fires, then the final output.
        For non-streaming requests it yields only the final output.

        Args:
            obj: The request being served.
            state: Tracking state returned by ``_send_one_request``.
            request: Originating HTTP request, used to detect client
                disconnects; may be ``None``.

        Raises:
            ValueError: If the client disconnected (the request is aborted
                first), or the scheduler finished it with an ``abort`` whose
                status code is ``HTTPStatus.BAD_REQUEST``.
            fastapi.HTTPException: If the scheduler finished it with an
                ``abort`` whose status code is ``HTTPStatus.SERVICE_UNAVAILABLE``.
        """
        while True:
            try:
                # Time out periodically so a disconnected client is noticed even
                # while no new output arrives.
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            # Only the most recent output is used; the backlog is discarded.
            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Mark ongoing LoRA request as finished.
                if self.server_args.enable_lora and obj.lora_path:
                    await self.lora_registry.release(obj.lora_id)

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code")
                        == HTTPStatus.SERVICE_UNAVAILABLE
                    ):
                        # This is an abort request initiated by scheduler.
                        # Delete the key to prevent resending abort request to the scheduler and
                        # to ensure aborted request state is cleaned up.
                        del self.rid_to_state[state.obj.rid]
                        raise fastapi.HTTPException(
                            status_code=finish_reason["status_code"],
                            detail=finish_reason["message"],
                        )
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )
823
824
825
826
827
828
829
830
831
832
833
834

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Dispatch every sub-request of a batched request and relay results.

        Async generator.
        - Non-streaming: yields one list holding one output per dispatched
          request, in dispatch order.
        - Streaming: yields each chunk as soon as it is ready. The chunk is
          tagged with ``result["index"]``, the position of its rid in
          dispatch order.

        When ``parallel_sample_num > 1``, each prompt is first sent once with
        ``max_new_tokens = 0`` to cache the common prefix. It is then expanded
        into ``parallel_sample_num`` requests with freshly generated rids.
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                with (
                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
                    else nullcontext()
                ):
                    for i in range(batch_size):
                        tmp_obj = obj[i]
                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
                        state = self._send_one_request(
                            tmp_obj, tokenized_obj, created_time
                        )
                        generators.append(
                            self._wait_one_response(tmp_obj, state, request)
                        )
                        rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(sub_obj) for sub_obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                # The shallow copy above shares sampling_params with
                # tokenized_objs[i]; copy it before zeroing max_new_tokens so
                # the expanded requests below keep the real limit.
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        # Re-arm this generator for its next chunk.
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        # This sub-request's generator is exhausted.
                        pass
922

923
    async def flush_cache(self) -> FlushCacheReqOutput:
        """Send a ``FlushCacheReqInput`` and return the first communicator response."""
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
Liangsheng Yin's avatar
Liangsheng Yin committed
925

926
927
    def abort_request(self, rid: str = "", abort_all: bool = False):
        """Send an ``AbortReq`` to the scheduler.

        Args:
            rid: Id of the request to abort.
            abort_all: When true, always send the abort. The ``rid`` tracking
                check below is skipped.

        A single-request abort for an ``rid`` not present in
        ``self.rid_to_state`` is a no-op: nothing is sent and no metric is
        recorded.
        """
        if not abort_all and rid not in self.rid_to_state:
            return
        req = AbortReq(rid, abort_all)
        self.send_to_scheduler.send_pyobj(req)

        if self.enable_metrics:
            self.metrics_collector.observe_one_aborted_request()

935
936
937
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        """Send a ``START_PROFILE`` request to the backend.

        Stack recording is enabled unless the caller passes ``with_stack=False``
        or the ``SGLANG_PROFILE_WITH_STACK`` env var (default ``"true"``) is
        false. The profile id is the current ``time.time()`` as a string.

        Returns:
            The result of ``_execute_profile``.

        Raises:
            RuntimeError: Via ``_execute_profile`` if the backend reports failure.
        """
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        # ``None`` from the caller counts as "not disabled".
        with_stack = False if with_stack is False or env_with_stack is False else True
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self):
        """Send a ``STOP_PROFILE`` request and return the ``_execute_profile`` result."""
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self, req: ProfileReq):
        """Send ``req`` through the profile communicator and return the first result.

        Raises:
            RuntimeError: With ``result.message`` when ``result.success`` is false.
        """
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result
971

972
    async def start_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.START_RECORD`` through its communicator."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
975

976
    async def stop_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.STOP_RECORD`` through its communicator."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
979

980
    async def dump_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.DUMP_RECORD`` through its communicator."""
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
983

984
    async def pause_generation(self):
        """Set ``self._is_updating`` and abort all requests.

        Both steps run while holding ``self._is_updating_cond``.
        ``continue_generation`` clears the flag again.
        """
        async with self._is_updating_cond:
            self._is_updating = True
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        """Clear ``self._is_updating`` and wake every waiter on the condition."""
        async with self._is_updating_cond:
            self._is_updating = False
            self._is_updating_cond.notify_all()
993

Chayenne's avatar
Chayenne committed
994
995
996
997
    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str, Any]:
        """Reload model weights from disk.

        ``obj.load_format`` defaults to ``server_args.load_format``. If
        ``obj.abort_all_requests`` is set, all requests are aborted first.

        Returns:
            ``(success, message, num_paused_requests)`` from
            ``_wait_for_model_update_from_disk``. With ``dp_size > 1`` the last
            element is a per-rank list.
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)
1014

1015
1016
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str, Any]:
        """Send the disk-update request to the scheduler and await the outcome.

        On success, every record of the current model is switched to
        ``obj.model_path`` / ``obj.load_format``. That covers
        ``served_model_name``, ``model_path`` and the matching ``server_args``
        fields, and it now happens for every ``dp_size``; previously the
        ``dp_size > 1`` path left ``served_model_name`` stale.

        Returns:
            ``(success, message, num_paused_requests)``. With ``dp_size > 1``:
            ``success`` is true only if every rank succeeded, ``message`` joins
            the per-rank messages with ``" | "``, and ``num_paused_requests``
            is a per-rank list.
        """
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            success = result.success
            message = result.message
            num_paused_requests = result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            results = await self.model_update_result
            success = all(r.success for r in results)
            message = " | ".join(r.message for r in results)
            num_paused_requests = [r.num_paused_requests for r in results]

        if success:
            self.served_model_name = obj.model_path
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        return success, message, num_paused_requests
1041

1042
1043
1044
1045
    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the init-weights-update-group communicator.

        Returns:
            ``(success, message)`` from the first response.

        Raises:
            AssertionError: If ``dp_size != 1``.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the update-weights-from-distributed communicator.

        If ``obj.abort_all_requests`` is set, all requests are aborted first.
        The model-update writer lock is held during the update.

        Returns:
            ``(success, message)`` from the first response.

        Raises:
            AssertionError: Unless ``dp_size == 1`` or DP attention is enabled.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message
1072

1073
1074
1075
1076
1077
1078
1079
    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Forward ``obj`` to the update-weights-from-tensor communicator.

        If ``obj.abort_all_requests`` is set, all requests are aborted first.
        The model-update writer lock is held during the update.

        Returns:
            ``(success, message)`` from the first response.

        Raises:
            AssertionError: Unless ``dp_size == 1`` or DP attention is enabled.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

1092
1093
1094
1095
1096
1097
1098
    async def load_lora_adapter(
        self,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        """Load a LoRA adapter on the backend and register it on success.

        Any ``ValueError`` raised inside is returned as
        ``LoadLoRAAdapterReqOutput(success=False, error_message=...)``. Examples
        are LoRA being disabled or ``max_loaded_loras`` being reached.

        Raises:
            AssertionError: If ``dp_size != 1``.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            # Hold the same lock as unload_lora_adapter so the capacity check,
            # backend load and registration are not interleaved with another
            # adapter update.
            async with self.lora_update_lock:
                if (
                    self.server_args.max_loaded_loras is not None
                    and self.lora_registry.num_registered_loras
                    >= self.server_args.max_loaded_loras
                ):
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Generate new uniquely identifiable LoRARef object.
                new_adapter = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Trigger the actual loading operation at the backend processes.
                obj.lora_id = new_adapter.lora_id
                result = (await self.update_lora_adapter_communicator(obj))[0]

                # Register the LoRA adapter only after loading is successful.
                if result.success:
                    await self.lora_registry.register(new_adapter)

                return result
        except ValueError as e:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(e),
            )
1149
1150
1151
1152
1153
1154
1155
1156

    async def unload_lora_adapter(
        self,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        """Unregister a LoRA adapter, then unload it from the backend.

        The adapter is unregistered first so new requests stop using it. The
        backend unload starts only after ``lora_registry.wait_for_unload``
        returns for its id. Any ``ValueError`` raised inside is returned as
        ``UnloadLoRAAdapterReqOutput(success=False, error_message=...)``.

        Raises:
            AssertionError: If ``obj.lora_name`` is ``None`` or ``dp_size != 1``.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            assert (
                obj.lora_name is not None
            ), "lora_name must be provided to unload LoRA adapter"

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
1191

1192
1193
1194
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch a parameter by name from the backend.

        Returns:
            The first response's ``parameter`` when ``dp_size == 1``, otherwise
            the list of ``parameter`` values from all responses.
        """
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the release-memory-occupation communicator.

        ``request`` is not used.
        """
        self.auto_create_handle_loop()
        communicator = self.release_memory_occupation_communicator
        await communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the resume-memory-occupation communicator.

        ``request`` is not used.
        """
        self.auto_create_handle_loop()
        communicator = self.resume_memory_occupation_communicator
        await communicator(obj)

1219
1220
1221
1222
1223
1224
1225
1226
    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` to the slow-down communicator; ``request`` is not used."""
        self.auto_create_handle_loop()
        communicator = self.slow_down_communicator
        await communicator(obj)

1227
1228
1229
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Ask the scheduler to open a session and wait for its reply.

        A random hex id is assigned when ``obj.session_id`` is ``None``.

        Returns:
            The value the session future resolves to. Returns ``None``
            immediately, without contacting the scheduler, if an open for the
            same ``session_id`` is already pending.
        """
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Send a session-close request to the scheduler without awaiting a reply."""
        sender = self.send_to_scheduler
        await sender.send_pyobj(obj)

1249
    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        """Query internal state and return ``internal_state`` from every response."""
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]
1256

1257
1258
1259
1260
1261
1262
1263
1264
    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> List[Any]:
        """Send ``obj`` to the set-internal-state communicator.

        Returns:
            A list with one entry per communicator response. The previous
            annotation (``SetInternalStateReqOutput``) did not match the list
            actually returned.
        """
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        # NOTE(review): reads ``res.internal_state``; confirm that
        # SetInternalStateReqOutput actually defines that field.
        return [res.internal_state for res in responses]

1265
1266
1267
1268
1269
1270
1271
1272
    async def get_load(self) -> dict:
        """Return ``{"load": current_load}``, refreshing it when possible.

        If another caller is already refreshing (lock held), the cached value
        is returned right away. Otherwise ``current_load`` is refreshed from
        the ``"load"`` key of the first ``get_internal_state()`` entry.
        """
        # TODO(lsyin): fake load report server
        if self.current_load_lock.locked():
            return {"load": self.current_load}
        async with self.current_load_lock:
            first_state = (await self.get_internal_state())[0]
            self.current_load = first_state["load"]
        return {"load": self.current_load}

1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
    def get_log_request_metadata(self):
        """Return ``(max_length, skip_names, out_skip_names)`` for request logging.

        All three are ``None`` when ``self.log_requests`` is false. Otherwise,
        by ``self.log_requests_level``:
          * 0: ``max_length = 1 << 30``. Skips the input fields below plus
            ``sampling_params``, and the output fields below.
          * 1: same as 0 but ``sampling_params`` is kept.
          * 2: ``max_length = 2048``, nothing skipped.
          * 3: ``max_length = 1 << 30``, nothing skipped.
        Each call returns freshly built sets.

        Raises:
            ValueError: For any other ``log_requests_level``.
        """
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            # Presumably large payload fields that levels 0/1 leave out of logs.
            input_skip = {
                "text",
                "input_ids",
                "input_embeds",
                "image_data",
                "audio_data",
                "lora_path",
            }
            output_skip = {"text", "output_ids", "embedding"}
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = input_skip | {"sampling_params"}
                out_skip_names = output_skip
            elif self.log_requests_level == 1:
                max_length = 1 << 30
                skip_names = input_skip
                out_skip_names = output_skip
            elif self.log_requests_level == 2:
                max_length = 2048
            elif self.log_requests_level == 3:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

1327
    def configure_logging(self, obj: ConfigureLoggingReq):
        """Apply the non-``None`` logging/dump settings from ``obj``.

        Afterwards ``self.log_request_metadata`` is recomputed, so changes to
        ``log_requests`` / ``log_requests_level`` take effect.
        """
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        if obj.crash_dump_folder is not None:
            self.crash_dump_folder = obj.crash_dump_folder
        # Use the module logger like the rest of this file. The module-level
        # ``logging.info`` helper logs on the root logger and may call
        # ``basicConfig()`` implicitly.
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()
1340

Lianmin Zheng's avatar
Lianmin Zheng committed
1341
    def create_abort_task(self, obj: GenerateReqInput):
        """Return ``BackgroundTasks`` that aborts ``obj``'s request(s) after 2 s.

        For a single request ``obj.rid`` is aborted. Otherwise each rid in
        ``obj.rid`` is aborted.
        """

        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

1355
    def auto_create_handle_loop(self):
        """Start the background loops once, on first call.

        Guarded by ``self.no_create_loop``, so repeat calls return immediately.
        The first call does the following:
          * schedules ``handle_loop`` and ``sigterm_watchdog`` (wrapped in
            ``print_exception_wrapper``) on the current event loop, adding the
            tasks to ``self.asyncio_tasks``;
          * stores the loop in ``self.event_loop``;
          * installs SIGTERM/SIGQUIT handlers, but only on the main thread.
        """
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )
1385

1386
1387
1388
1389
1390
1391
    def dump_requests_before_crash(self):
        """Best-effort dump of recent and in-flight requests when going down.

        Runs at most once (guarded by ``self.crash_dump_performed``) and does
        nothing when ``self.crash_dump_folder`` is not configured or there is
        nothing to dump. The dump contains the entries of
        ``self.crash_dump_request_list`` plus one record per unfinished request
        in ``self.rid_to_state``, pickled together with ``self.server_args`` to
        ``<crash_dump_folder>/<hostname>/crash_dump_<timestamp>.pkl``. The file
        is then uploaded to the ``sglang_crash_dump`` GCS bucket under
        ``<hostname>/<file name>``.

        Failures while writing or uploading are logged and swallowed: callers
        invoke this on crash/shutdown paths that must still go on to kill the
        process tree afterwards.
        """
        if self.crash_dump_performed:
            logger.info(
                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
            )
            return

        if not self.crash_dump_folder:
            return

        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
        self.crash_dump_performed = True

        # NFS auto-detection is currently disabled: the dump is always written
        # under crash_dump_folder and then uploaded to GCS.
        use_nfs_dir = False
        if not use_nfs_dir:
            logger.error(
                "Expected NFS directory is not available or writable. Uploading to GCS."
            )

        data_to_dump = []
        if self.crash_dump_request_list:
            data_to_dump.extend(self.crash_dump_request_list)

        # Add unfinished requests from rid_to_state, in the same
        # (obj, last_output, created_time, dump_time) layout as finished ones.
        unfinished_requests = [
            (
                state.obj,
                state.out_list[-1] if state.out_list else {},
                state.created_time,
                time.time(),
            )
            for state in self.rid_to_state.values()
            if not state.finished
        ]
        data_to_dump.extend(unfinished_requests)

        if not data_to_dump:
            return

        import socket

        # os.getenv() returns None when HOSTNAME is unset, which both
        # os.path.join and string concatenation reject with TypeError.
        hostname = os.getenv("HOSTNAME") or socket.gethostname()
        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
        filename = os.path.join(self.crash_dump_folder, hostname, object_name)
        num_finished = len(data_to_dump) - len(unfinished_requests)

        try:
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            # Include server_args in the dump
            data_to_dump_with_server_args = {
                "server_args": self.server_args,
                "requests": data_to_dump,
            }
            with open(filename, "wb") as f:
                pickle.dump(data_to_dump_with_server_args, f)
        except Exception:
            logger.exception(f"Failed to write crash dump to {filename}")
            return

        logger.error(
            f"Dumped {num_finished} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )

        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
            # Imported lazily: google-cloud-storage is only needed on this path.
            from google.cloud import storage

            client = storage.Client()
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_name)
            # if_generation_match=0: only create the object, never overwrite one.
            blob.upload_from_filename(source_file_path, if_generation_match=0)
            logger.error(
                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
            )

        if not use_nfs_dir:
            try:
                _upload_file_to_gcs(
                    "sglang_crash_dump",
                    filename,
                    hostname + "/" + object_name,
                )
            except Exception:
                # Missing package, credentials or network must not abort the
                # crash path; the local dump file is still available.
                logger.exception(f"Failed to upload crash dump {filename} to GCS")

1469
1470
    async def sigterm_watchdog(self):
        """Wait for a graceful-exit request, drain requests, then kill the process tree.

        Polls ``self.gracefully_exit`` every 5 s. Once it is set, the number of
        tracked requests (``len(self.rid_to_state)``) is re-checked every 5 s and
        the wait ends when:

        * ``self.health_check_failed`` is set -> ``dump_requests_before_crash()``
          is called and the wait ends immediately;
        * the ``SGL_FORCE_SHUTDOWN`` env flag is set -> the wait ends
          immediately without a dump;
        * no requests remain -> ``dump_requests_before_crash()`` is called.

        Afterwards the whole process tree, this process included, is killed.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            pending = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    pending,
                )
                self.dump_requests_before_crash()
                break

            if get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                # if force shutdown flag set, exit immediately
                logger.error(
                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
                    pending,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {pending}"
            )
            if pending == 0:
                self.dump_requests_before_crash()
                break
            await asyncio.sleep(5)

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
1505

Lianmin Zheng's avatar
Lianmin Zheng committed
1506
    async def handle_loop(self):
        """Forever receive objects from the detokenizer and dispatch each one.

        After every dispatched object, ``self.last_receive_tstamp`` is updated
        to the current wall-clock time.
        """
        while True:
            message = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(message)
            # Stamped after dispatch, i.e. when this message was fully handled.
            self.last_receive_tstamp = time.time()
1513

1514
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        """Distribute one batched output to the per-request states.

        For every ``rid`` in ``recv_obj.rids`` that is still tracked in
        ``self.rid_to_state`` this:

        1. builds ``meta_info`` (id, finish reason, token counts, optional
           logprobs and hidden states);
        2. builds an output dict whose payload depends on the output type:
           text + output ids (``BatchStrOut``), output ids only
           (``BatchTokenIDOut``) or the embedding (``BatchEmbeddingOut``);
           ``BatchMultimodalOut`` raises ``NotImplementedError``;
        3. on finish, records ``finished_time`` / ``e2e_latency`` and removes the
           rid from ``self.rid_to_state``;
        4. appends the output to ``state.out_list`` and sets ``state.event``;
        5. updates metrics, the periodic request dump and the crash-dump buffer
           when those features are enabled and ``state.obj.log_metrics`` is set.
        """
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                # The rid is no longer tracked (presumably already finished or
                # aborted); this output is dropped.
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                # Token texts are only decoded into the logprobs when requested
                # and a tokenizer is available (not skip_tokenizer_init).
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                # "text" is always the cumulative text; "output_ids" is only the
                # ids new since the previous emission when the request streams.
                state.text += recv_obj.output_strs[i]
                if state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "text": state.text,
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                # Incremental ids only when both the server (stream_output) and
                # the request (stream) ask for streaming; otherwise a full copy.
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                # NOTE(review): if this runs inside handle_loop, this exception
                # reaches print_exception_wrapper, which kills the process tree.
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                # meta_info is the same dict object referenced by out_dict, so
                # the keys added here still appear in the delivered output.
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            # Publish the output and wake whoever is waiting on this request.
            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)
1613
1614
1615
1616

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        """Accumulate one request's logprobs into ``state`` and render them into ``meta_info``.

        The slice at ``recv_obj_index`` of each logprob array in ``recv_obj`` is
        appended to the matching running list on ``state``; the accumulated
        lists are then converted to ``(logprob, token_id, text_or_None)`` tuples:

        * always: ``input_token_logprobs`` and ``output_token_logprobs``;
        * if ``top_logprobs_num > 0``: ``input_top_logprobs`` and
          ``output_top_logprobs``;
        * if ``token_ids_logprob`` is not None: ``input_token_ids_logprobs`` and
          ``output_token_ids_logprobs``.

        Input-side arrays are only appended when the batch-level list in
        ``recv_obj`` is non-empty. Nothing happens when
        ``recv_obj.input_token_logprobs_val`` is None.
        """
        if recv_obj.input_token_logprobs_val is None:
            return

        pos = recv_obj_index
        as_text = return_text_in_logprobs

        def _append_chunk(dst_val, dst_idx, src_val, src_idx):
            # Extend an accumulated (values, token ids) pair with this request's slice.
            dst_val.extend(src_val[pos])
            dst_idx.extend(src_idx[pos])

        # Plain per-token logprobs.
        if len(recv_obj.input_token_logprobs_val) > 0:
            _append_chunk(
                state.input_token_logprobs_val,
                state.input_token_logprobs_idx,
                recv_obj.input_token_logprobs_val,
                recv_obj.input_token_logprobs_idx,
            )
        _append_chunk(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            recv_obj.output_token_logprobs_val,
            recv_obj.output_token_logprobs_idx,
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val, state.input_token_logprobs_idx, as_text
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val, state.output_token_logprobs_idx, as_text
        )

        # Top-k alternatives per position.
        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                _append_chunk(
                    state.input_top_logprobs_val,
                    state.input_top_logprobs_idx,
                    recv_obj.input_top_logprobs_val,
                    recv_obj.input_top_logprobs_idx,
                )
            _append_chunk(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                recv_obj.output_top_logprobs_val,
                recv_obj.output_top_logprobs_idx,
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val, state.input_top_logprobs_idx, as_text
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val, state.output_top_logprobs_idx, as_text
            )

        # Logprobs of explicitly requested token ids per position.
        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                _append_chunk(
                    state.input_token_ids_logprobs_val,
                    state.input_token_ids_logprobs_idx,
                    recv_obj.input_token_ids_logprobs_val,
                    recv_obj.input_token_ids_logprobs_idx,
                )
            _append_chunk(
                state.output_token_ids_logprobs_val,
                state.output_token_ids_logprobs_idx,
                recv_obj.output_token_ids_logprobs_val,
                recv_obj.output_token_ids_logprobs_idx,
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                as_text,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    state.output_token_ids_logprobs_val,
                    state.output_token_ids_logprobs_idx,
                    as_text,
                )
            )

1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Zip logprobs with their token ids and (optionally) decoded token text.

        Returns a list of ``(logprob, token_id, text)`` tuples, truncated to the
        shorter input like ``zip``. ``text`` is ``None`` unless
        ``decode_to_text`` is true, in which case it comes from
        ``self.tokenizer.batch_decode(token_logprobs_idx)``.
        """
        if decode_to_text:
            assert self.tokenizer is not None
            decoded = self.tokenizer.batch_decode(token_logprobs_idx)
            return [
                (logprob, token_id, text)
                for logprob, token_id, text in zip(
                    token_logprobs_val, token_logprobs_idx, decoded
                )
            ]
        return [
            (logprob, token_id, None)
            for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
        ]

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Detokenize a per-position list of top-k logprobs.

        Each position with a truthy value list is converted with
        ``detokenize_logprob_tokens``; positions with an empty/None value list
        map to ``None``. The result has one entry per position.
        """
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        converted = []
        for position, values in enumerate(token_logprobs_val):
            if not values:
                converted.append(None)
                continue
            converted.append(
                self.detokenize_logprob_tokens(
                    values, token_logprobs_idx[position], decode_to_text
                )
            )
        return converted

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        """Report latency and token-count observations for request ``i``.

        * First token (skipped in PREFILL disaggregation mode): records the
          time to first token relative to ``state.created_time``.
        * Later calls: when the completion-token count grew, records the
          inter-token latency since the last observation and the number of new
          tokens.
        * When the request is finished: records prompt/completion/cached token
          counts, end-to-end latency and whether a ``json_schema``, ``regex``,
          ``ebnf`` or ``structural_tag`` constraint was requested.
        """
        completion_tokens = 0
        if getattr(recv_obj, "completion_tokens", None):
            completion_tokens = recv_obj.completion_tokens[i]

        is_first_token = (
            state.first_token_time == 0.0
            and self.disaggregation_mode != DisaggregationMode.PREFILL
        )
        if is_first_token:
            now = time.time()
            state.first_token_time = now
            state.last_time = now
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            new_tokens = completion_tokens - state.last_completion_tokens
            if new_tokens:
                now = time.time()
                self.metrics_collector.observe_inter_token_latency(
                    now - state.last_time,
                    new_tokens,
                )
                state.last_time = now
                state.last_completion_tokens = completion_tokens

        if not state.finished:
            return

        params = state.obj.sampling_params
        has_grammar = (
            params.get("json_schema", None)
            or params.get("regex", None)
            or params.get("ebnf", None)
            or params.get("structural_tag", None)
        )
        self.metrics_collector.observe_one_finished_request(
            recv_obj.prompt_tokens[i],
            completion_tokens,
            recv_obj.cached_tokens[i],
            state.finished_time - state.created_time,
            has_grammar,
        )

    def dump_requests(self, state: ReqState, out_dict: dict):
        """Buffer a finished request and flush the buffer once it is full.

        Each record is ``(obj, out_dict, created_time, now)``. When the buffer
        reaches ``self.dump_requests_threshold`` entries it is written via
        ``_dump_data_to_file`` to ``<dump_requests_folder>/<timestamp>.pkl`` and
        replaced with a fresh empty list.
        """
        record = (state.obj, out_dict, state.created_time, time.time())
        self.dump_request_list.append(record)

        if len(self.dump_request_list) < self.dump_requests_threshold:
            return

        stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        filename = os.path.join(self.dump_requests_folder, stamp + ".pkl")
        self._dump_data_to_file(
            data_list=self.dump_request_list,
            filename=filename,
            log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
        )
        self.dump_request_list = []

1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        """Remember a finished request for a possible crash dump.

        Appends ``(obj, out_dict, created_time, now)`` and then evicts, from the
        left, every record whose recorded time is 300 s (5 minutes) or more
        older than now, so only a recent window is kept.
        """
        now = time.time()
        window = self.crash_dump_request_list
        window.append((state.obj, out_dict, state.created_time, now))
        while window and now - window[0][3] >= 300:
            window.popleft()

1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
    def _dump_data_to_file(
        self, data_list: List[Tuple], filename: str, log_message: str
    ):
        """Log ``log_message`` and pickle ``data_list`` to ``filename`` in a worker thread.

        ``data_list`` is shallow-copied right away, so the caller may reset its
        list immediately after this returns. The pickled payload is
        ``{"server_args": self.server_args, "requests": <copy of data_list>}``.
        Must be called while an asyncio event loop is running, since it uses
        ``asyncio.create_task``.
        """
        logger.info(log_message)
        to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_list.copy(),
        }

        def background_task():
            dump_dir = os.path.dirname(filename)
            # os.makedirs("") raises FileNotFoundError for a bare file name.
            if dump_dir:
                os.makedirs(dump_dir, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        task = asyncio.create_task(asyncio.to_thread(background_task))
        # The event loop only keeps weak references to tasks, so an unreferenced
        # task may be garbage-collected before the write finishes. Hold it in
        # self.asyncio_tasks (the set auto_create_handle_loop also adds tasks
        # to) until it completes.
        self.asyncio_tasks.add(task)
        task.add_done_callback(self.asyncio_tasks.discard)

Lianmin Zheng's avatar
Lianmin Zheng committed
1827
    def _handle_abort_req(self, recv_obj):
        """Deliver an abort for ``recv_obj.rid`` to the request waiting on it.

        Health-check requests are ignored. Otherwise the request state is
        marked finished and a final output is appended whose
        ``meta_info.finish_reason`` is ``recv_obj.finished_reason`` when
        provided, or a synthesized "Abort before prefill" reason with zero token
        counts. The waiter is then woken via ``state.event``.

        An abort for a rid that is no longer tracked (e.g. the request finished
        and was removed before the abort arrived) is logged and dropped rather
        than raising ``KeyError``. Presumably this handler runs inside
        ``handle_loop``, where an exception would reach
        ``print_exception_wrapper`` and kill the process tree.
        """
        if is_health_check_generate_req(recv_obj):
            return

        state = self.rid_to_state.get(recv_obj.rid, None)
        if state is None:
            logger.warning(
                f"Received abort for {recv_obj.rid=} but the state was deleted in TokenizerManager."
            )
            return

        state.finished = True
        if recv_obj.finished_reason:
            out = {
                "meta_info": {
                    "id": recv_obj.rid,
                    "finish_reason": recv_obj.finished_reason,
                },
            }
        else:
            out = {
                "text": "",
                "meta_info": {
                    "id": recv_obj.rid,
                    "finish_reason": {
                        "type": "abort",
                        "message": "Abort before prefill",
                    },
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                },
            }
        state.out_list.append(out)
        state.event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
1854

1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
1865
            # set future if the all results are received
1866
1867
1868
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """Score each item by the next-token probabilities of ``label_token_ids``.

        ``query`` and ``items`` are either both text (``items`` may be a single
        string) or both token-id lists; each item is concatenated with the query
        (item first when ``item_first``) and a 1-token generation is run with
        ``token_ids_logprob=label_token_ids``. For every item the result lists,
        in ``label_token_ids`` order, either ``exp(logprob)`` (0.0 for labels
        missing from the output) or, with ``apply_softmax``, a softmax over the
        label logprobs.

        Raises ``ValueError`` when ``label_token_ids`` is missing, a label id is
        outside the tokenizer vocabulary, or the query/items types do not match.
        See Engine.score() for more details.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                if token_id >= vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        both_text = isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        )
        both_tokens = (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        )

        if both_text:
            text_items = [items] if isinstance(items, str) else items
            prompts = [
                f"{item}{query}" if item_first else f"{query}{item}"
                for item in text_items
            ]
            batch_request = GenerateReqInput(
                text=prompts,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        elif both_tokens:
            input_ids_list = [
                item + query if item_first else query + item for item in items
            ]
            batch_request = GenerateReqInput(
                input_ids=input_ids_list,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()

        scores = []
        for result in results:
            # Logprobs at the single generated position, keyed by token id.
            first_position = result["meta_info"].get("output_token_ids_logprobs", [])[0]
            logprob_by_token = {
                token_id: logprob
                for logprob, token_id, _ in first_position
                if token_id in label_token_ids
            }
            label_logprobs = [
                logprob_by_token.get(token_id, float("-inf"))
                for token_id in label_token_ids
            ]

            if apply_softmax:
                scores.append(
                    torch.softmax(torch.tensor(label_logprobs), dim=0).tolist()
                )
            else:
                scores.append(
                    [
                        0.0 if lp == float("-inf") else math.exp(lp)
                        for lp in label_logprobs
                    ]
                )

        return scores

1963

1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
class ServerStatus(Enum):
    """Coarse lifecycle/health state of the server."""

    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        """Return True only for ``ServerStatus.Up``."""
        return self is ServerStatus.Up


1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    """Choose how tensors are passed between processes.

    A truthy ``server_args.dist_init_addr`` is treated as a multi-node setup and
    selects ``"default"`` (the CPU transport, per the original comment);
    otherwise ``"cuda_ipc"`` is used.
    """
    # Fallback to default CPU transport for multi-node
    return "default" if server_args.dist_init_addr else "cuda_ipc"


1984
1985
1986
1987
1988
1989
1990
1991
1992
async def print_exception_wrapper(func):
    """Await ``func()``; on any exception, log it and tear the process tree down.

    Sometimes an asyncio function does not print its exception, so this
    wrapper logs the traceback explicitly. When ``func`` is a bound method of a
    ``TokenizerManager``, its ``dump_requests_before_crash`` is attempted first.
    The process tree (including this process) is then killed and the process
    exits with status 1, even if the dump itself fails.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {tb_text}")

        owner = getattr(func, "__self__", None)
        if isinstance(owner, TokenizerManager):
            try:
                owner.dump_requests_before_crash()
            except Exception:
                # The dump is best-effort; it must never prevent the teardown below.
                logger.exception("Failed to dump requests before crash.")

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


2000
class SignalHandler:
    """Event-loop signal callbacks for a running ``TokenizerManager``."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        """Request a graceful shutdown by setting ``gracefully_exit``.

        The actual draining and exit are done by the manager's
        ``sigterm_watchdog``, which polls that flag.
        """
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """Handle SIGQUIT from a failed child: dump requests, then kill the process tree.

        The dump is best-effort; a failure is logged and the kill still happens.
        """
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        try:
            self.tokenizer_manager.dump_requests_before_crash()
        except Exception:
            logger.exception("Failed to dump requests before crash.")
        kill_process_tree(os.getpid())

2017
2018
2019
2020
2021

T = TypeVar("T")


class _Communicator(Generic[T]):
2022
2023
    """Note: The communicator now only run up to 1 in-flight request at any time."""

2024
2025
2026
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
2027
        self._result_event: Optional[asyncio.Event] = None
2028
        self._result_values: Optional[List[T]] = None
2029
        self._ready_queue: Deque[asyncio.Future] = deque()
2030
2031

    async def __call__(self, obj):
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
2043
        self._result_values = []
2044
        await self._result_event.wait()
2045
        result_values = self._result_values
2046
2047
2048
2049
2050
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

2051
2052
2053
2054
2055
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
2056
            self._result_event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
2057
2058
2059
2060
2061
2062
2063


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
2064
# | http       | yes          | validation      | background task | fast api              | del in _handle_abort_req    |
Lianmin Zheng's avatar
Lianmin Zheng committed
2065
2066
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
2067
# | http       | no           | validation      | http exception  | http exception        | del in _handle_abort_req    |
Lianmin Zheng's avatar
Lianmin Zheng committed
2068
2069
2070
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#