# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import math
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from contextlib import nullcontext
from datetime import datetime
from enum import Enum
from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOutput,
    BatchMultimodalOutput,
    BatchStrOutput,
    BatchTokenIDOutput,
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    FreezeGCReq,
    GenerateReqInput,
    GetLoadReqInput,
    HealthCheckOutput,
    MultiTokenizerWrapper,
    OpenSessionReqOutput,
    SessionParams,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    WatchLoadUpdateReq,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.managers.scheduler import is_health_check_generate_req
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
    trace_get_proc_propagate_context,
    trace_req_finish,
    trace_req_start,
    trace_slice_end,
    trace_slice_start,
)
from sglang.srt.utils import (
    configure_gc_warning,
    dataclass_to_string_truncated,
    freeze_gc,
    get_bool_env_var,
    get_origin_rid,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.srt.utils.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ReqState:
    """Store the state of a request."""

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
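
    # Lifecycle sketch (illustrative only, not executed): a ReqState is created when a
    # request is dispatched and is updated as scheduler output arrives, mirroring
    # _send_one_request and the batch-output handlers:
    #
    #   state = ReqState([], False, asyncio.Event(), obj, created_time=time.time())
    #   rid_to_state[obj.rid] = state       # registered at dispatch time
    #   state.out_list.append(out_dict)     # one entry per output chunk
    #   state.event.set()                   # wakes the awaiting _wait_one_response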


class TokenizerManager(TokenizerCommunicatorMixin):
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder
        self.enable_trace = server_args.enable_trace

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self.max_req_input_len = None  # Will be set later in engine.py

        speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.reserve_input_token_num = (
            0
            if speculative_algorithm.is_none()
            else server_args.speculative_num_draft_tokens
        )
        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
        self.multi_item_delimiter_text = None

        if self.model_config.is_multimodal:
            import_processors("sglang.srt.multimodal.processors")
            try:
                _processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
            except ValueError as e:
                error_message = str(e)
                if "does not have a slow version" in error_message:
                    logger.info(
                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
                    )
                    _processor = get_processor(
                        server_args.tokenizer_path,
                        tokenizer_mode=server_args.tokenizer_mode,
                        trust_remote_code=server_args.trust_remote_code,
                        revision=server_args.revision,
                        use_fast=True,
                    )
                else:
                    raise e
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # We want to parallelize the image pre-processing so we create an executor for it
            # We create mm_processor for any skip_tokenizer_init setting to make sure we
            # still encode images even when skip_tokenizer_init=True.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor, transport_mode
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
                self._initialize_multi_item_delimiter_text()
        else:
            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )
                self._initialize_multi_item_delimiter_text()
        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
        if (
            server_args.enable_dynamic_batch_tokenizer
            and not server_args.skip_tokenizer_init
        ):
            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
                self.tokenizer,
                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
            )
        else:
            self.async_dynamic_batch_tokenizer = None

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        if self.server_args.tokenizer_worker_num > 1:
            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
            self.send_to_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
            )
        else:
            self.send_to_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
            )
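
        # A minimal sketch (illustrative; endpoint names are hypothetical) of the plain
        # ZMQ pattern that get_zmq_socket wraps here: a PULL socket receives detokenized
        # results and a PUSH socket forwards tokenized requests over IPC.
        #
        #   ctx = zmq.asyncio.Context()
        #   pull = ctx.socket(zmq.PULL); pull.bind("ipc:///tmp/tokenizer_ipc")
        #   push = ctx.socket(zmq.PUSH); push.connect("ipc:///tmp/scheduler_input_ipc")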

        # Request states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.asyncio_tasks = set()

        # Health check
        self.server_status = ServerStatus.Starting
        self.gracefully_exit = False
        self.last_receive_tstamp = 0

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Session
        self.session_futures = {}  # session_id -> asyncio event

        # Weight updates
        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

        # LoRA
        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths)
        # Lock to serialize LoRA update operations.
        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
        # LoRA updates and inference to overlap.
        self.lora_update_lock = asyncio.Lock()
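
        # Typical pairing used later in this class (sketch): a request acquires a LoRA ID
        # before inference and releases it once the request finishes, e.g.
        #
        #   lora_id = await self.lora_registry.acquire(lora_path)
        #   ...  # run the request
        #   await self.lora_registry.release(lora_id)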

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.bootstrap_server = start_disagg_service(self.server_args)

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            labels = {
                "model_name": self.server_args.served_model_name,
                # TODO: Add lora name/path in the future,
            }
            if server_args.tokenizer_metrics_allowed_custom_labels:
                for label in server_args.tokenizer_metrics_allowed_custom_labels:
                    labels[label] = ""
            self.metrics_collector = TokenizerMetricsCollector(
                server_args=server_args,
                labels=labels,
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Configure GC warning
        if self.server_args.gc_warning_threshold_secs > 0.0:
            configure_gc_warning(self.server_args.gc_warning_threshold_secs)

        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOutput,
                        BatchEmbeddingOutput,
                        BatchTokenIDOutput,
                        BatchMultimodalOutput,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    FreezeGCReq,
                    lambda x: None,
                ),  # Ignored: the scheduler may skip the detokenizer and forward this back to the tokenizer manager.
                (HealthCheckOutput, lambda x: None),
            ]
        )
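
        # Dispatch sketch (illustrative): the receive loop presumably calls
        # self._result_dispatcher(recv_obj); the dispatcher then picks the handler whose
        # registered type matches, e.g. a BatchStrOutput goes to _handle_batch_output
        # and an AbortReq goes to _handle_abort_req.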

        self.init_communicators(server_args)

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()
        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        if self.server_args.tokenizer_worker_num > 1:
            # Modify rid, add worker_id
            if isinstance(obj.rid, list):
                # If it's an array, add worker_id prefix to each element
                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
            else:
                # If it's a single value, add worker_id prefix
                obj.rid = f"{self.worker_id}_{obj.rid}"

        if self.enable_trace:
            self._trace_request_start(obj, created_time)

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.is_pause_cond:
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)

        async with self.model_update_lock.reader_lock:
            if self.server_args.enable_lora and obj.lora_path:
                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

            if obj.is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                state = self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, state, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    def _detect_input_format(
        self, texts: Union[str, List[str]], is_cross_encoder: bool
    ) -> str:
        """Detect the format of input texts for proper tokenization handling.

        Returns:
            - "single_string": Regular single text like "Hello world"
            - "batch_strings": Regular batch like ["Hello", "World"]
            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
        """
        if isinstance(texts, str):
            return "single_string"

        if (
            is_cross_encoder
            and len(texts) > 0
            and isinstance(texts[0], list)
            and len(texts[0]) == 2
        ):
            return "cross_encoder_pairs"

        return "batch_strings"

    def _prepare_tokenizer_input(
        self, texts: Union[str, List[str]], input_format: str
    ) -> Union[List[str], List[List[str]]]:
        """Prepare input for the tokenizer based on detected format."""
        if input_format == "single_string":
            return [texts]  # Wrap single string for batch processing
        elif input_format == "cross_encoder_pairs":
            return texts  # Already in correct format: [["query", "doc"]]
        else:  # batch_strings
            return texts  # Already in correct format: ["text1", "text2"]

    def _extract_tokenizer_results(
        self,
        input_ids: List[List[int]],
        token_type_ids: Optional[List[List[int]]],
        input_format: str,
        original_batch_size: int,
    ) -> Union[
        Tuple[List[int], Optional[List[int]]],
        Tuple[List[List[int]], Optional[List[List[int]]]],
    ]:
        """Extract results from tokenizer output based on input format."""

        # For single inputs (string or single cross-encoder pair), extract first element
        if (
            input_format in ["single_string", "cross_encoder_pairs"]
            and original_batch_size == 1
        ):
            single_input_ids = input_ids[0] if input_ids else []
            single_token_type_ids = token_type_ids[0] if token_type_ids else None
            return single_input_ids, single_token_type_ids

        # For true batches, return as-is
        return input_ids, token_type_ids

    async def _tokenize_texts(
        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
    ) -> Union[
        Tuple[List[int], Optional[List[int]]],
        Tuple[List[List[int]], Optional[List[List[int]]]],
    ]:
        """
        Tokenize text(s) using the appropriate tokenizer strategy.

        This method handles multiple input formats and chooses between async dynamic
        batch tokenizer (for single texts only) and regular tokenizer.

        Args:
            texts: Text input in various formats:

                   Regular cases:
                   - Single string: "How are you?"
                   - Batch of strings: ["Hello", "World", "How are you?"]

                   Cross-encoder cases (sentence pairs for similarity/ranking):
                   - Single pair: [["query text", "document text"]]
                   - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]

            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
                             Enables proper handling of sentence pairs with segment IDs.

        Returns:
            Single input cases:
                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
                Example: ([101, 2129, 102], [0, 0, 0]) for single text
                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair

            Batch input cases:
                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch

            Note: token_type_ids is None unless is_cross_encoder=True.
        """
        if not texts or self.tokenizer is None:
            raise ValueError("texts cannot be empty and tokenizer must be initialized")

        # Step 1: Detect input format and prepare for tokenization
        input_format = self._detect_input_format(texts, is_cross_encoder)
        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
        original_batch_size = len(texts) if not isinstance(texts, str) else 1

        # Step 2: Set up tokenizer arguments
        tokenizer_kwargs = (
            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
        )

        # Step 3: Choose tokenization strategy
        use_async_tokenizer = (
            self.async_dynamic_batch_tokenizer is not None
            and input_format == "single_string"
        )

        if use_async_tokenizer:
            logger.debug("Using async dynamic batch tokenizer for single text")
            result = await self.async_dynamic_batch_tokenizer.encode(
                tokenizer_input[0], **tokenizer_kwargs
            )
            # Convert to batch format for consistency
            input_ids = [result["input_ids"]]
            token_type_ids = (
                [result["token_type_ids"]]
                if is_cross_encoder and result.get("token_type_ids")
                else None
            )
        else:
            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
            input_ids = encoded["input_ids"]
            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None

        # Step 4: Extract results based on input format
        return self._extract_tokenizer_results(
            input_ids, token_type_ids, input_format, original_batch_size
        )
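
    # Usage sketch for the tokenization helpers above (illustrative, assuming a
    # HuggingFace-style tokenizer is loaded):
    #
    #   ids, tt = await self._tokenize_texts("How are you?")            # List[int], None
    #   ids, tt = await self._tokenize_texts(["Hello", "World"])        # List[List[int]], None
    #   ids, tt = await self._tokenize_texts([["query", "doc"]], True)  # List[int], List[int]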

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        token_type_ids = None
        is_cross_encoder_request = (
            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
        )
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )

            input_ids, token_type_ids = await self._tokenize_texts(
                input_text, is_cross_encoder_request
            )

        if self.mm_processor and obj.contains_mm_input():
            if not isinstance(obj.image_data, list):
                obj.image_data = [obj.image_data]
            if not isinstance(obj.audio_data, list):
                obj.audio_data = [obj.audio_data]
            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                audio_data=obj.audio_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            if mm_inputs and "input_ids" in mm_inputs:
                input_ids = mm_inputs["input_ids"]
        else:
            mm_inputs = None

        self._validate_one_request(obj, input_ids)
        trace_slice_end("tokenize", obj.rid)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )

611
    def _validate_one_request(
612
613
614
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validates that the input token count and the requested token count don't exceed the model's context length."""
        # FIXME: unify the length validation logic with the one in the scheduler.
        _max_req_len = self.context_len

        input_token_num = len(input_ids) if input_ids is not None else 0
        input_token_num += self.reserve_input_token_num
        if input_token_num >= self.context_len:
            if self.server_args.allow_auto_truncate:
                logger.warning(
                    f"The input ({input_token_num} tokens) is longer than the "
                    f"model's context length ({self.context_len} tokens). "
                    "Truncating the input."
                )
                del input_ids[_max_req_len:]
                input_token_num = len(input_ids)
            else:
                raise ValueError(
                    f"The input ({input_token_num} tokens) is longer than the "
                    f"model's context length ({self.context_len} tokens)."
                )

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= _max_req_len
        ):
            if self.server_args.allow_auto_truncate:
                logger.warning(
                    f"Requested token count ({input_token_num} input + {max_new_tokens} new) "
                    f"exceeds the model's context length ({self.context_len} tokens). "
                    "Truncating max_new_tokens."
                )
                obj.sampling_params["max_new_tokens"] = max(
                    0, _max_req_len - input_token_num
                )
            else:
                total_tokens = max_new_tokens + input_token_num
                error_msg = (
                    f"Requested token count exceeds the model's maximum context length "
                    f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                    f"tokens: {input_token_num} tokens from the input messages and "
                    f"{max_new_tokens} tokens for the completion. Please reduce the number "
                    f"of tokens in the input messages or the completion to fit within the limit."
                )
                raise ValueError(error_msg)

        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "The server is not configured to return the hidden states. "
                    "Please set `--enable-return-hidden-states` to enable this feature."
                )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logit-processor` to enable this feature."
                )
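
    # Worked example for the checks above (illustrative numbers): with context_len=4096,
    # an effective input of 4000 tokens and max_new_tokens=200 exceeds the limit; with
    # --allow-auto-truncate the request proceeds with max_new_tokens clamped to
    # 4096 - 4000 = 96, otherwise a ValueError is raised.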

    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
        if any(id >= vocab_size for id in input_ids):
            raise ValueError(
                f"The input_ids {input_ids} contains values greater than or equal to the vocab size ({vocab_size})."
            )

    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[Union[List[float], None]] = None,
        mm_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters."""
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify(self.model_config.vocab_size)

        # Build return object
        if isinstance(obj, GenerateReqInput):
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

            tokenized_obj = TokenizedGenerateReqInput(
                input_text,
                input_ids,
                mm_inputs,
                sampling_params,
                obj.return_logprob,
                obj.logprob_start_len,
                obj.top_logprobs_num,
                obj.token_ids_logprob,
                obj.stream,
                rid=obj.rid,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_id=obj.lora_id,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
                priority=obj.priority,
                extra_key=obj.extra_key,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                input_text,
                input_ids,
                mm_inputs,
                token_type_ids,
                sampling_params,
                rid=obj.rid,
                priority=obj.priority,
            )

        return tokenized_obj
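
    # Merge semantics used above for preferred sampling params (small illustration):
    # request-level values override the server-level defaults, e.g.
    #
    #   preferred = {"temperature": 0.7, "top_p": 0.9}
    #   requested = {"temperature": 0.2}
    #   {**preferred, **requested} -> {"temperature": 0.2, "top_p": 0.9}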

    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only."""
        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Check if any request is a cross-encoder request
        is_cross_encoder_request = any(
            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
            for req in requests
        )

        # Batch tokenize all texts using unified method
        input_ids_list, token_type_ids_list = await self._tokenize_texts(
            texts, is_cross_encoder_request
        )

        # Process all requests
        tokenized_objs = []
        for i, req in enumerate(requests):
            self._validate_one_request(obj[i], input_ids_list[i])
            token_type_ids = (
                token_type_ids_list[i] if token_type_ids_list is not None else None
            )
            tokenized_objs.append(
                self._create_tokenized_object(
                    req, req.text, input_ids_list[i], None, None, token_type_ids
                )
            )
            trace_slice_end("tokenize", req.rid)
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
            if self.is_generation and obj[i].contains_mm_input():
                raise ValueError(
                    "Batch tokenization is not supported for multimodal inputs. Do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        trace_slice_start("dispatch", obj.rid)
        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
        return state

    def _send_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_objs: List[
            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
        ],
        created_time: Optional[float] = None,
    ):
        """Send a batch of tokenized requests as a single batched request to the scheduler."""
        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
        else:
            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)

        self.send_to_scheduler.send_pyobj(batch_req)

        # Create states for each individual request in the batch
        for i, tokenized_obj in enumerate(tokenized_objs):
            tmp_obj = obj[i]
            state = ReqState(
                [], False, asyncio.Event(), tmp_obj, created_time=created_time
            )
            self.rid_to_state[tmp_obj.rid] = state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                    if finish_reason.get("type") == "abort" and finish_reason.get(
                        "status_code"
                    ) in (
                        HTTPStatus.SERVICE_UNAVAILABLE,
                        HTTPStatus.INTERNAL_SERVER_ERROR,
                    ):
                        # This is an abort request initiated by scheduler.
                        # Delete the key to prevent resending abort request to the scheduler and
                        # to ensure aborted request state is cleaned up.
                        if state.obj.rid in self.rid_to_state:
                            del self.rid_to_state[state.obj.rid]

                        # Mark ongoing LoRA request as finished.
                        if self.server_args.enable_lora and state.obj.lora_path:
                            await self.lora_registry.release(state.obj.lora_id)

                        raise fastapi.HTTPException(
                            status_code=finish_reason["status_code"],
                            detail=finish_reason["message"],
                        )
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                # Send as a single batched request
                self._send_batch_request(obj, tokenized_objs, created_time)

                # Set up generators for each request in the batch
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    generators.append(
                        self._wait_one_response(
                            tmp_obj, self.rid_to_state[tmp_obj.rid], request
                        )
                    )
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                with (
                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
                    else nullcontext()
                ):
                    for i in range(batch_size):
                        tmp_obj = obj[i]
                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
                        state = self._send_one_request(
                            tmp_obj, tokenized_obj, created_time
                        )
                        generators.append(
                            self._wait_one_response(tmp_obj, state, request)
                        )
                        rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    def abort_request(self, rid: str = "", abort_all: bool = False):
        if not abort_all and rid not in self.rid_to_state:
            return
        req = AbortReq(rid=rid, abort_all=abort_all)
        self.send_to_scheduler.send_pyobj(req)
        if self.enable_metrics:
            # TODO: also use custom_labels from the request
            self.metrics_collector.observe_one_aborted_request(
                self.metrics_collector.labels
            )

    async def pause_generation(self):
        async with self.is_pause_cond:
            self.is_pause = True
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        if self.server_args.tokenizer_worker_num > 1:
            obj = MultiTokenizerWrapper(self.worker_id, obj)
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        if obj.crash_dump_folder is not None:
            self.crash_dump_folder = obj.crash_dump_folder
        logging.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()

    async def freeze_gc(self):
        """Send a freeze_gc message to the scheduler first, then freeze locally."""
        self.send_to_scheduler.send_pyobj(FreezeGCReq())
        freeze_gc("Tokenizer Manager")
        return None

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks

    def auto_create_handle_loop(self):
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add signal handler when the tokenizer manager is not in
        # the main thread due to the CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the sigquit handler in the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.watch_load_thread))
        )

    def dump_requests_before_crash(self):
        if self.crash_dump_performed:
            logger.info(
                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
            )
            return

        if not self.crash_dump_folder:
            return

        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
        self.crash_dump_performed = True

        # Check if NFS directory is available
        # expected_nfs_dir = "/" + self.crash_dump_folder.lstrip("/").split("/")[0]
        # use_nfs_dir = os.path.isdir(expected_nfs_dir) and os.access(
        #     expected_nfs_dir, os.W_OK
        # )
        use_nfs_dir = False
        if not use_nfs_dir:
            logger.error(
                f"Expected NFS directory is not available or writable. Uploading to GCS."
            )

        data_to_dump = []
        if self.crash_dump_request_list:
            data_to_dump.extend(self.crash_dump_request_list)

        # Add unfinished requests from rid_to_state
        unfinished_requests = []
        for rid, state in self.rid_to_state.items():
            if not state.finished:
                unfinished_requests.append(
                    (
                        state.obj,
                        state.out_list[-1] if state.out_list else {},
                        state.created_time,
                        time.time(),
                    )
                )
        if unfinished_requests:
            data_to_dump.extend(unfinished_requests)

        if not data_to_dump:
            return

        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
        filename = os.path.join(
            self.crash_dump_folder,
            os.getenv("HOSTNAME", None),
            object_name,
        )

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Include server_args in the dump
        data_to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_to_dump,
        }
        with open(filename, "wb") as f:
            pickle.dump(data_to_dump_with_server_args, f)
        logger.error(
            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )
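
        # Resulting layout, sketched from the code above:
        #   <crash_dump_folder>/<HOSTNAME>/crash_dump_<YYYY-mm-dd_HH-MM-SS>.pkl
        # containing a pickled {"server_args": ..., "requests": [...]} dict.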

        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
            from google.cloud import storage

            client = storage.Client()
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_name)
            blob.upload_from_filename(source_file_path, if_generation_match=0)
            logger.error(
                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
            )

        if not use_nfs_dir:
            _upload_file_to_gcs(
                "sglang_crash_dump",
                filename,
                os.getenv("HOSTNAME", None) + "/" + object_name,
            )

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)
            remaining_rids = list(self.rid_to_state.keys())

            if self.server_status == ServerStatus.UnHealthy:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Force exiting."
                )
                self.dump_requests_before_crash()
                break

            elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                # if force shutdown flag set, exit immediately
                logger.error(
                    "Signal SIGTERM received while force shutdown flag set. Force exiting."
                )
                break

            logger.info(
                f"Gracefully exiting... Remaining number of requests {remain_num_req}. Remaining requests {remaining_rids=}."
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                self.dump_requests_before_crash()
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""
        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.time()

    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOutput,
            BatchEmbeddingOutput,
            BatchMultimodalOutput,
            BatchTokenIDOutput,
        ],
    ):
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            origin_rid = rid
            if self.server_args.tokenizer_worker_num > 1:
                origin_rid = get_origin_rid(rid)
            # Build meta_info and return value
            meta_info = {
                "id": origin_rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
                "weight_version": self.server_args.weight_version,
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOutput):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOutput):
                state.text += recv_obj.output_strs[i]
                if state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "text": state.text,
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOutput):
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOutput):
                raise NotImplementedError("BatchMultimodalOutput is not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOutput)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time

                trace_req_finish(rid, ts=int(state.finished_time * 1e9))

                del self.rid_to_state[rid]

                # Mark ongoing LoRA request as finished.
                if self.server_args.enable_lora and state.obj.lora_path:
                    asyncio.create_task(self.lora_registry.release(state.obj.lora_id))

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)
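
    # Shape sketch of the per-request chunk appended above: BatchStrOutput yields
    # {"text": ..., "output_ids": [...], "meta_info": {...}}, BatchTokenIDOutput
    # omits "text", and BatchEmbeddingOutput yields {"embedding": ..., "meta_info": {...}}.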

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOutput,
        recv_obj_index: int,
    ):
        if recv_obj.input_token_logprobs_val is None:
            return

        if len(recv_obj.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                recv_obj.input_token_logprobs_val[recv_obj_index]
            )
            state.input_token_logprobs_idx.extend(
                recv_obj.input_token_logprobs_idx[recv_obj_index]
            )
        state.output_token_logprobs_val.extend(
            recv_obj.output_token_logprobs_val[recv_obj_index]
        )
        state.output_token_logprobs_idx.extend(
            recv_obj.output_token_logprobs_idx[recv_obj_index]
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val,
            state.input_token_logprobs_idx,
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    recv_obj.input_top_logprobs_val[recv_obj_index]
                )
                state.input_top_logprobs_idx.extend(
                    recv_obj.input_top_logprobs_idx[recv_obj_index]
                )
            state.output_top_logprobs_val.extend(
                recv_obj.output_top_logprobs_val[recv_obj_index]
            )
            state.output_top_logprobs_idx.extend(
                recv_obj.output_top_logprobs_idx[recv_obj_index]
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val,
                state.input_top_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
                )
                state.input_token_ids_logprobs_idx.extend(
                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
                )
            state.output_token_ids_logprobs_val.extend(
                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
            )
            state.output_token_ids_logprobs_idx.extend(
                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    state.output_token_ids_logprobs_val,
                    state.output_token_ids_logprobs_idx,
                    return_text_in_logprobs,
                )
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
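
    # Sketch with hypothetical values:
    #   detokenize_logprob_tokens([-0.1, -2.3], [5, 7], False)
    #   -> [(-0.1, 5, None), (-2.3, 7, None)]
    # With decode_to_text=True the third slot holds the decoded token text instead of None.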

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOutput, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        custom_labels = getattr(state.obj, "custom_labels", None)
        labels = (
            {**self.metrics_collector.labels, **custom_labels}
            if custom_labels
            else self.metrics_collector.labels
        )
        if (
            state.first_token_time == 0.0
            and self.disaggregation_mode != DisaggregationMode.PREFILL
        ):
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                labels, state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    labels,
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            has_grammar = (
                state.obj.sampling_params.get("json_schema", None)
                or state.obj.sampling_params.get("regex", None)
                or state.obj.sampling_params.get("ebnf", None)
                or state.obj.sampling_params.get("structural_tag", None)
            )
            self.metrics_collector.observe_one_finished_request(
                labels,
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
                has_grammar,
            )
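
    # Timing sketch for collect_metrics above: the first chunk of a request records
    # time-to-first-token (first_token_time - created_time); later chunks that add
    # tokens report the elapsed interval plus the number of new tokens; a finished
    # request additionally records its e2e latency and token counts.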

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            self._dump_data_to_file(
                data_list=self.dump_request_list,
                filename=filename,
                log_message=f"Dump {len(self.dump_request_list)} requests to {filename}",
            )
            self.dump_request_list = []

    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        current_time = time.time()
        self.crash_dump_request_list.append(
            (state.obj, out_dict, state.created_time, current_time)
        )
        # Remove requests older than 5 minutes based on finish time
        while (
            self.crash_dump_request_list
            and current_time - self.crash_dump_request_list[0][3] >= 300
        ):
            self.crash_dump_request_list.popleft()

    def _dump_data_to_file(
        self, data_list: List[Tuple], filename: str, log_message: str
    ):
        logger.info(log_message)
        to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_list.copy(),
        }

        def background_task():
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        asyncio.create_task(asyncio.to_thread(background_task))

    def _handle_abort_req(self, recv_obj: AbortReq):
        if is_health_check_generate_req(recv_obj):
            return
        state = self.rid_to_state[recv_obj.rid]
        origin_rid = recv_obj.rid
        if self.server_args.tokenizer_worker_num > 1:
            origin_rid = get_origin_rid(origin_rid)
        state.finished = True
        if recv_obj.finished_reason:
            out = {
                "meta_info": {
                    "id": recv_obj.rid,
                    "finish_reason": recv_obj.finished_reason,
                },
            }
        else:
            out = {
                "text": "",
                "meta_info": {
                    "id": origin_rid,
                    "finish_reason": {
                        "type": "abort",
                        "message": "Abort before prefill",
                    },
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                },
            }
        state.out_list.append(out)
        state.event.set()

    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # set future if all results are received
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

    def _initialize_multi_item_delimiter_text(self):
        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
        if (
            hasattr(self.server_args, "multi_item_scoring_delimiter")
            and self.server_args.multi_item_scoring_delimiter is not None
            and self.tokenizer is not None
        ):
            try:
                self.multi_item_delimiter_text = self.tokenizer.decode(
                    [self.server_args.multi_item_scoring_delimiter],
                    skip_special_tokens=False,
                )
            except Exception as e:
                logger.warning(
                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
                )
                self.multi_item_delimiter_text = None

    def _build_multi_item_token_sequence(
        self, query: List[int], items: List[List[int]], delimiter_token_id: int
    ) -> List[int]:
        """
        Build a single token sequence for multi-item scoring.
        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>

        Args:
            query: Query token IDs
            items: List of item token ID sequences
            delimiter_token_id: Token ID to use as delimiter

        Returns:
            Combined token sequence
        """
        combined_sequence = query[:]  # Start with query

        for item in items:
            combined_sequence.append(delimiter_token_id)  # Add delimiter
            combined_sequence.extend(item)  # Add item tokens

        # Add final delimiter after the last item for logprob extraction
        combined_sequence.append(delimiter_token_id)

        return combined_sequence
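
    # Sketch with hypothetical token IDs:
    #   _build_multi_item_token_sequence([11, 12], [[21], [31, 32]], 9)
    #   -> [11, 12, 9, 21, 9, 31, 32, 9]
    # i.e. a delimiter before every item plus one trailing delimiter so the last
    # item also has a boundary position for logprob extraction.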

    def _extract_logprobs_for_tokens(
        self, logprobs_data: List, label_token_ids: List[int]
    ) -> Dict[int, float]:
        """
        Extract logprobs for specified token IDs from logprobs data.

        Args:
            logprobs_data: List of (logprob, token_id, text) tuples
            label_token_ids: Token IDs to extract logprobs for

        Returns:
            Dictionary mapping token_id to logprob
        """
        logprobs = {}
        if logprobs_data:
            for logprob, token_id, _ in logprobs_data:
                if token_id in label_token_ids:
                    logprobs[token_id] = logprob
        return logprobs
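
    # Sketch with hypothetical values:
    #   _extract_logprobs_for_tokens([(-0.2, 5, "yes"), (-1.7, 7, "no")], [5, 9])
    #   -> {5: -0.2}; token id 9 is absent and is treated as -inf / 0.0 downstream.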

    def _convert_logprobs_to_scores(
        self,
        logprobs: Dict[int, float],
        label_token_ids: List[int],
        apply_softmax: bool,
    ) -> List[float]:
        """
        Convert logprobs dictionary to ordered score list.

        Args:
            logprobs: Dictionary mapping token_id to logprob
            label_token_ids: Token IDs in desired order
            apply_softmax: Whether to apply softmax normalization

        Returns:
            List of scores in the same order as label_token_ids
        """
        score_list = [
            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
        ]

        if apply_softmax:
            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
        else:
            # Convert logprobs to probabilities if not using softmax
            score_list = [
                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
            ]

        return score_list
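
    # Sketch with hypothetical values: logprobs={5: -0.1, 7: -2.3}, label_token_ids=[5, 7, 9]:
    #   apply_softmax=False -> [exp(-0.1), exp(-2.3), 0.0]   (missing ids map to 0.0)
    #   apply_softmax=True  -> softmax([-0.1, -2.3, -inf])   (missing ids get probability 0)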

    def _process_multi_item_scoring_results(
        self,
        results: Any,
        items: List,
        label_token_ids: List[int],
        apply_softmax: bool,
        batch_request=None,
    ) -> List[List[float]]:
        """
        Process results from multi-item scoring request.
        Extracts logprobs at delimiter positions from input_token_ids_logprobs.

        Args:
            results: Results from generate_request
            items: List of items being scored
            label_token_ids: Token IDs to extract scores for
            apply_softmax: Whether to apply softmax normalization
            batch_request: The original batch request containing input sequence

        Returns:
            List of score lists, one for each item
        """
        single_result = results[0] if isinstance(results, list) else results

        # For multi-item scoring, logprobs are in input_token_ids_logprobs
        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])

        if not input_logprobs:
            raise RuntimeError(
                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
                "This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
            )

        scores = []
        num_items = len(items) if isinstance(items, list) else 1

        # Check if we have the expected number of logprobs
        expected_logprobs_count = num_items + 1
        if len(input_logprobs) != expected_logprobs_count:
            raise RuntimeError(
                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
                f"with {num_items} items, but got {len(input_logprobs)}. "
                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
            )

        # Skip the first delimiter: it marks the boundary between the query and the
        # first item, not an item boundary, so item scoring starts at index 1.
        start_idx = 1 if len(input_logprobs) > 1 else 0

        # Process logprobs for each item position (excluding first delimiter)
        for item_idx in range(num_items):
            logprob_idx = start_idx + item_idx
            item_logprobs_data = input_logprobs[logprob_idx]
            logprobs = self._extract_logprobs_for_tokens(
                item_logprobs_data, label_token_ids
            )
            score_list = self._convert_logprobs_to_scores(
                logprobs, label_token_ids, apply_softmax
            )
            scores.append(score_list)

        return scores
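
    # Shape sketch: with 3 items the request is expected to return 4 entries in
    # input_token_ids_logprobs (one per delimiter); index 0 marks the query/item
    # boundary and is skipped, so indices 1-3 score item1-item3.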

    def _process_single_item_scoring_results(
        self, results: Any, label_token_ids: List[int], apply_softmax: bool
    ) -> List[List[float]]:
        """
        Process results from single-item scoring request.
        Single-item scoring results are stored in output_token_ids_logprobs.

        Args:
            results: Results from generate_request
            label_token_ids: Token IDs to extract scores for
            apply_softmax: Whether to apply softmax normalization

        Returns:
            List of score lists, one for each result
        """
        scores = []

        for result in results:
            # For single-item scoring, logprobs are in output_token_ids_logprobs
            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

            if not output_logprobs or len(output_logprobs) == 0:
                raise RuntimeError(
                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
                )

            # Extract logprobs for the first (and only) position
            logprobs = self._extract_logprobs_for_tokens(
                output_logprobs[0], label_token_ids
            )
            score_list = self._convert_logprobs_to_scores(
                logprobs, label_token_ids, apply_softmax
            )
            scores.append(score_list)

        return scores

    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
        Score the probability of specified token IDs appearing after the given (query + item) pair.

        This method supports two scoring approaches:
        1. Single-Item scoring (default): Process each query+item pair independently
        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
           multiple items into a single sequence using delimiter for efficient processing.
           Note: item_first parameter is ignored in multi-item scoring mode since it uses
           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>

           Multi-item scoring works with both text and pre-tokenized inputs:
           - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
           - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>

        Args:
            query: The query text or pre-tokenized query token IDs
            items: The item text(s) or pre-tokenized item token IDs
            label_token_ids: List of token IDs to compute probabilities for
            apply_softmax: Whether to normalize probabilities using softmax
            item_first: If True, prepend items to query. Ignored for multi-item scoring.
            request: Optional FastAPI request object

        Returns:
            List of lists containing probabilities for each item and each label token
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                if token_id >= vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        # Check if multi-item scoring is enabled by presence of delimiter
        use_multi_item_scoring = (
            self.server_args.multi_item_scoring_delimiter is not None
            and self.multi_item_delimiter_text is not None
        )

        batch_request = GenerateReqInput(
            token_ids_logprob=label_token_ids,
            return_logprob=True,
            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
            logprob_start_len=0 if use_multi_item_scoring else -1,
            stream=False,
            sampling_params={"max_new_tokens": 0},
        )

        # Handle string or tokenized query/items
        if isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items

            if use_multi_item_scoring:
                # Multi-item scoring: create single prompt with delimiter text
                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
                # (item_first is ignored for multi-item scoring)
                delimiter = self.multi_item_delimiter_text
                combined_items = delimiter.join(items_list)
                # Add final delimiter after the last item for logprob extraction
                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
                batch_request.text = [single_prompt]
            else:
                # Single-item scoring: create separate prompts for each item
                if item_first:
                    prompts = [f"{item}{query}" for item in items_list]
                else:
                    prompts = [f"{query}{item}" for item in items_list]
                batch_request.text = prompts

        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
            if use_multi_item_scoring:
                # Multi-item scoring: concatenate with delimiter token ID
                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
                combined_input_ids = self._build_multi_item_token_sequence(
                    query, items, delimiter_token_id
                )
                batch_request.input_ids = [combined_input_ids]
            else:
                # Single-item scoring: process each item separately
                if item_first:
                    input_ids_list = [item + query for item in items]
                else:
                    input_ids_list = [query + item for item in items]
                batch_request.input_ids = input_ids_list
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()

        if use_multi_item_scoring:
            # Multi-item scoring: extract scores from input_token_ids_logprobs
            return self._process_multi_item_scoring_results(
                results, items, label_token_ids, apply_softmax, batch_request
            )
        else:
            # Single-item scoring: process each result separately
            return self._process_single_item_scoring_results(
                results, label_token_ids, apply_softmax
            )
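
    # Hedged usage sketch (caller-side names such as yes_id/no_id are illustrative):
    #   scores = await tokenizer_manager.score_request(
    #       query="Is the following review positive? ",
    #       items=["Great product!", "Terrible quality."],
    #       label_token_ids=[yes_id, no_id],
    #       apply_softmax=True,
    #   )
    #   # -> [[p_yes, p_no], [p_yes, p_no]], one inner list per item.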

    async def watch_load_thread(self):
        # Only for dp_controller when dp_size > 1
        if (
            self.server_args.dp_size == 1
            or self.server_args.load_balance_method == "round_robin"
        ):
            return

        while True:
            await asyncio.sleep(self.server_args.load_watch_interval)
            loads = await self.get_load_communicator(GetLoadReqInput())
            load_update_req = WatchLoadUpdateReq(loads=loads)
            self.send_to_scheduler.send_pyobj(load_update_req)

    def _trace_request_start(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        if obj.is_single:
            bootstrap_room = (
                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
            )
            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
        else:
            for i in range(len(obj.rid)):
                bootstrap_room = (
                    obj.bootstrap_room[i]
                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                    else None
                )
                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
                trace_slice_start(
                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
                )


class ServerStatus(Enum):
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"


def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

    if is_cross_node:
        # Fallback to default CPU transport for multi-node
        return "default"
    else:
        return "cuda_ipc"


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio function does not print its exception,
    so we add this wrapper to log it and dump requests before exiting.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
            func.__self__.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


class SignalHandler:
    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        logger.error(
            f"SIGQUIT received. {signum=}, {frame=}. It usually means one child failed."
        )
        self.tokenizer_manager.dump_requests_before_crash()
        kill_process_tree(os.getpid())


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | validation      | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | validation      | http exception  | http exception        | del in _handle_abort_req    |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#