# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TokenizerManager is a process that tokenizes the text."""

import asyncio
import copy
import dataclasses
import json
import logging
import math
import os
import pickle
import signal
import sys
import threading
import time
import uuid
from collections import deque
from datetime import datetime
from http import HTTPStatus
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
from fastapi import BackgroundTasks

from sglang.srt.aio_rwlock import RWLock
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchEmbeddingOut,
    BatchMultimodalOut,
    BatchStrOut,
    BatchTokenIDOut,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    EmbeddingReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    HealthCheckOutput,
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ProfileReqOutput,
    ProfileReqType,
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
    SessionParams,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import TensorTransportMode
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
    dataclass_to_string_truncated,
    get_bool_env_var,
    get_zmq_socket,
    kill_process_tree,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

logger = logging.getLogger(__name__)

@dataclasses.dataclass
class ReqState:
    """Store the state a request."""

    out_list: List[Dict[Any, Any]]
    finished: bool
    event: asyncio.Event
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
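
    # Illustrative lifecycle (sketch, based on the methods below): `_send_one_request`
    # creates a ReqState and stores it in `rid_to_state`; as detokenized outputs arrive,
    # dicts are appended to `out_list` and `event` is set; the waiter in
    # `_wait_one_response` then reads `out_list[-1]`, clears `event`, and stops once
    # `finished` is True.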


def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

    if is_cross_node:
        # Fallback to default CPU transport for multi-node
        return "default"
    else:
        return "cuda_ipc"


class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Init inter-process communication
        context = zmq.asyncio.Context(2)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self._updating = False
        self._cond = asyncio.Condition()

        if self.model_config.is_multimodal:
            import_processors()
            _processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # We want to parallelize the image pre-processing so we create an executor for it.
            # We create mm_processor even when skip_tokenizer_init is set, so that images
            # are still encoded in that case.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor, transport_mode
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
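        # e.g. (illustrative): a request carrying lora_path="my-adapter" is later mapped
        # by `self.lora_registry.acquire(...)` to that adapter's unique internal ID
        # before the request is sent to the scheduler.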

        # Store states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.log_request_metadata = self.get_log_request_metadata()
        self.session_futures = {}  # session_id -> asyncio event
        self.max_req_input_len = None
        self.asyncio_tasks = set()

        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start the KV bootstrap server on the prefill tokenizer manager only
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            kv_bootstrap_server_class = get_kv_class(
                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
            )
            self.bootstrap_server = kv_bootstrap_server_class(
                self.server_args.disaggregation_bootstrap_port
            )
            is_create_store = (
                self.server_args.node_rank == 0
                and self.server_args.disaggregation_transfer_backend == "ascend"
            )
            if is_create_store:
                try:
                    from mf_adapter import create_config_store

                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                    create_config_store(ascend_url)
                except Exception as e:
                    error_message = f"Failed create mf store, invalid ascend_url."
                    error_message += f" With exception {e}"
                    raise error_message

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_lora_adapter_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )
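
        # Routing sketch (illustrative): the background handle loop roughly does
        #   recv_obj = await self.recv_from_detokenizer.recv_pyobj()
        #   self._result_dispatcher(recv_obj)
        # so each message type registered above is forwarded to its handler.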

    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
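        """Handle one generate/embedding request, yielding output dicts as they arrive.

        Illustrative usage sketch (hypothetical caller code, not part of this module):

            obj = GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})
            async for out in tokenizer_manager.generate_request(obj):
                print(out)
        """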
        created_time = time.time()
        async with self._cond:
            await self._cond.wait_for(lambda: not self._updating)

        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            logger.info(
                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
            )

        async with self.model_update_lock.reader_lock:
            if obj.is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                state = self._send_one_request(obj, tokenized_obj, created_time)
                async for response in self._wait_one_response(obj, state, request):
                    yield response
            else:
                async for response in self._handle_batch_request(
                    obj, request, created_time
                ):
                    yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request."""
        # Tokenize
        input_embeds = None
        input_text = obj.text
        token_type_ids = None
        is_cross_encoder_request = (
            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
        )
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            encoded = self.tokenizer(
                input_text, return_token_type_ids=is_cross_encoder_request
            )

            input_ids = encoded["input_ids"]
            if is_cross_encoder_request:
                input_ids = encoded["input_ids"][0]
                token_type_ids = encoded.get("token_type_ids", [None])[0]

        if self.mm_processor and obj.contains_mm_input():
            if not isinstance(obj.image_data, list):
                obj.image_data = [obj.image_data]
            if not isinstance(obj.audio_data, list):
                obj.audio_data = [obj.audio_data]
            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                audio_data=obj.audio_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            if mm_inputs and "input_ids" in mm_inputs:
                input_ids = mm_inputs["input_ids"]
        else:
            mm_inputs = None

        if self.server_args.enable_lora and obj.lora_path:
            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)

        self._validate_one_request(obj, input_ids)
        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )

    def _validate_one_request(
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

        input_token_num = len(input_ids) if input_ids is not None else 0
        # Check if input alone exceeds context length
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= self.context_len
        ):
            total_tokens = max_new_tokens + input_token_num
            error_msg = (
                f"Requested token count exceeds the model's maximum context length "
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )
            raise ValueError(error_msg)
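
        # Worked example (illustrative numbers): with context_len = 8192, a prompt of
        # 8000 tokens combined with max_new_tokens = 500 is rejected (8000 + 500 >= 8192),
        # while max_new_tokens = 100 is accepted (8000 + 100 < 8192).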

        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "The server is not configured to return the hidden states. "
                    "Please set `--enable-return-hidden-states` to enable this feature."
                )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logits-processor` to enable this feature."
                )

    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
        if any(id >= vocab_size for id in input_ids):
            raise ValueError(
                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
            )

    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[List[float]] = None,
        mm_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters."""
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
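        # e.g. (illustrative): preferred_sampling_params = {"temperature": 0.6} merged
        # with a request's {"max_new_tokens": 32} gives
        # {"temperature": 0.6, "max_new_tokens": 32}; keys from the request win on conflict.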
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify(self.model_config.vocab_size)

        # Build return object
        if isinstance(obj, GenerateReqInput):
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )

            tokenized_obj = TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                sampling_params,
                obj.return_logprob,
                obj.logprob_start_len,
                obj.top_logprobs_num,
                obj.token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
            )
        elif isinstance(obj, EmbeddingReqInput):
            tokenized_obj = TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                token_type_ids,
                sampling_params,
            )

        return tokenized_obj

    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only."""
        logger.debug(f"Starting batch tokenization for {batch_size} text requests")

        # Collect requests and texts
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
        input_ids_list = encoded["input_ids"]

        # Process all requests
        tokenized_objs = []
        for i, req in enumerate(requests):
            self._validate_one_request(obj[i], input_ids_list[i])
            tokenized_objs.append(
                self._create_tokenized_object(
                    req, req.text, input_ids_list[i], None, None
                )
            )
        logger.debug(f"Completed batch processing for {batch_size} requests")
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
            if self.is_generation and obj[i].contains_mm_input():
                raise ValueError(
                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if obj[i].input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
        self.rid_to_state[obj.rid] = state
        return state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request."""
        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if request is not None and await request.is_disconnected():
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                for i in range(batch_size):
                    tmp_obj = obj[i]
                    tokenized_obj = await self._tokenize_one_request(tmp_obj)
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(obj) for obj in objs)
            )

            # Cache the common prefix for parallel sampling
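            # (Sketch of the idea: sending each prompt once with max_new_tokens = 0 runs
            # prefill only, so the shared prefix is cached and the n expanded samples
            # below can reuse it.)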
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

    async def flush_cache(self) -> FlushCacheReqOutput:
        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    def abort_request(self, rid: str = "", abort_all: bool = False):
        if not abort_all and rid not in self.rid_to_state:
            return
        req = AbortReq(rid, abort_all)
        self.send_to_scheduler.send_pyobj(req)

        if self.enable_metrics:
            self.metrics_collector.observe_one_aborted_request()

    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        self.auto_create_handle_loop()
        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
        with_stack = False if with_stack is False or env_with_stack is False else True
        req = ProfileReq(
            type=ProfileReqType.START_PROFILE,
            output_dir=output_dir,
            start_step=start_step,
            num_steps=num_steps,
            activities=activities,
            with_stack=with_stack,
            record_shapes=record_shapes,
            profile_by_stage=profile_by_stage,
            profile_id=str(time.time()),
        )
        return await self._execute_profile(req)

    async def stop_profile(self):
        self.auto_create_handle_loop()
        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(req)

    async def _execute_profile(self, req: ProfileReq):
        result = (await self.profile_communicator(req))[0]
        if not result.success:
            raise RuntimeError(result.message)
        return result

    async def start_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)

    async def stop_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)

    async def dump_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def pause_generation(self):
        async with self._cond:
            self._updating = True
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        async with self._cond:
            self._updating = False
            self._cond.notify_all()
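
    # Pause/continue sketch (illustrative): `pause_generation` sets `_updating` and aborts
    # all in-flight requests, so new `generate_request` calls block in
    # `self._cond.wait_for(lambda: not self._updating)` until `continue_generation`
    # clears the flag and notifies the waiters.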

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str]:
        self.send_to_scheduler.send_pyobj(obj)
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp = []
            result = await self.model_update_result

            all_success = all([r.success for r in result])
            if all_success is True:
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            all_message = [r.message for r in result]
            all_message = " | ".join(all_message)
            all_paused_requests = [r.num_paused_requests for r in result]
            return all_success, all_message, all_paused_requests

    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # This means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def load_lora_adapter(
        self,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()
        if not self.server_args.enable_lora:
            raise ValueError(
                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
            )

        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
        # with dp_size > 1.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start load Lora adapter. Lora name=%s, path=%s",
            obj.lora_name,
            obj.lora_path,
        )

        async with self.model_update_lock.writer_lock:
            # Generate new uniquely identifiable LoRARef object.
            new_adapter = LoRARef(
                lora_name=obj.lora_name,
                lora_path=obj.lora_path,
            )

            # Register the new adapter in the registry.
            obj.lora_id = new_adapter.lora_id
            result = (await self.update_lora_adapter_communicator(obj))[0]
            if result.success:
                await self.lora_registry.register(new_adapter)

            return result

    async def unload_lora_adapter(
        self,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()
        if not self.server_args.enable_lora:
            raise ValueError(
                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
            )

        assert (
            obj.lora_name is not None
        ), "lora_name must be provided to unload LoRA adapter"

        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
        # with dp_size > 1.
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start unload Lora adapter. Lora name=%s",
            obj.lora_name,
        )

        async with self.model_update_lock.writer_lock:
            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
            result = (await self.update_lora_adapter_communicator(obj))[0]

            return result

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()
        results = await self.get_weights_by_name_communicator(obj)
        all_parameters = [r.parameter for r in results]
        if self.server_args.dp_size == 1:
            return all_parameters[0]
        else:
            return all_parameters

    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.release_memory_occupation_communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.resume_memory_occupation_communicator(obj)

    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        self.auto_create_handle_loop()
        await self.slow_down_communicator(obj)

    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        self.send_to_scheduler.send_pyobj(obj)

        self.session_futures[obj.session_id] = asyncio.Future()
        session_id = await self.session_futures[obj.session_id]
        del self.session_futures[obj.session_id]
        return session_id

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        await self.send_to_scheduler.send_pyobj(obj)

    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        req = GetInternalStateReq()
        responses: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(req)
        )
        # Many DP ranks
        return [res.internal_state for res in responses]

    async def get_load(self) -> dict:
        # TODO(lsyin): fake load report server
        if not self.current_load_lock.locked():
            async with self.current_load_lock:
                internal_state = await self.get_internal_state()
                self.current_load = internal_state[0]["load"]
        return {"load": self.current_load}

    async def set_internal_state(
        self, obj: SetInternalStateReq
    ) -> SetInternalStateReqOutput:
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        return [res.internal_state for res in responses]

    def get_log_request_metadata(self):
        max_length = None
        skip_names = None
        out_skip_names = None
        if self.log_requests:
            if self.log_requests_level == 0:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                        "sampling_params",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                        "embedding",
                    ]
                )
            elif self.log_requests_level == 1:
                max_length = 1 << 30
                skip_names = set(
                    [
                        "text",
                        "input_ids",
                        "input_embeds",
                        "image_data",
                        "audio_data",
                        "lora_path",
                    ]
                )
                out_skip_names = set(
                    [
                        "text",
                        "output_ids",
                        "embedding",
                    ]
                )
            elif self.log_requests_level == 2:
                max_length = 2048
            elif self.log_requests_level == 3:
                max_length = 1 << 30
            else:
                raise ValueError(
                    f"Invalid --log-requests-level: {self.log_requests_level=}"
                )
        return max_length, skip_names, out_skip_names

    def configure_logging(self, obj: ConfigureLoggingReq):
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        if obj.crash_dump_folder is not None:
            self.crash_dump_folder = obj.crash_dump_folder
        logger.info(f"Config logging: {obj=}")
        self.log_request_metadata = self.get_log_request_metadata()
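    # Illustrative sketch (the HTTP wiring is hypothetical): callers apply only
    # the fields that are not None, e.g.
    #
    #     tokenizer_manager.configure_logging(
    #         ConfigureLoggingReq(log_requests=True, log_requests_level=2)
    #     )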

    def create_abort_task(self, obj: GenerateReqInput):
        # Abort the request if the client is disconnected.
        async def abort_request():
            await asyncio.sleep(2)
            if obj.is_single:
                self.abort_request(obj.rid)
            else:
                for rid in obj.rid:
                    self.abort_request(rid)

        background_tasks = BackgroundTasks()
        background_tasks.add_task(abort_request)
        return background_tasks
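    # Illustrative sketch (hypothetical handler, not part of this file): the HTTP
    # layer attaches the returned BackgroundTasks to its response, so abort_request
    # runs after the response finishes (e.g. when the client disconnects mid-stream).
    #
    #     @app.post("/generate")
    #     async def generate(obj: GenerateReqInput, request: fastapi.Request):
    #         background = tokenizer_manager.create_abort_task(obj)
    #         return StreamingResponse(stream_results(), background=background)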

    def auto_create_handle_loop(self):
        if self.no_create_loop:
            return

        self.no_create_loop = True
        loop = asyncio.get_event_loop()
        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.handle_loop))
        )

        self.event_loop = loop

        # We cannot add a signal handler when the tokenizer manager is not in
        # the main thread due to a CPython limitation.
        if threading.current_thread() is threading.main_thread():
            signal_handler = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, signal_handler.sigterm_handler)
            # Update the signal handler for the process. It overrides the SIGQUIT
            # handler installed during the launch phase.
            loop.add_signal_handler(
                signal.SIGQUIT, signal_handler.running_phase_sigquit_handler
            )
        else:
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )

        self.asyncio_tasks.add(
            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
        )

    def dump_requests_before_crash(self):
        if self.crash_dump_performed:
            logger.info(
                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
            )
            return
        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
        self.crash_dump_performed = True
        if not self.crash_dump_folder:
            return

        data_to_dump = []
        if self.crash_dump_request_list:
            data_to_dump.extend(self.crash_dump_request_list)

        # Add unfinished requests from rid_to_state
        unfinished_requests = []
        for rid, state in self.rid_to_state.items():
            if not state.finished:
                unfinished_requests.append(
                    (state.obj, {}, state.created_time, time.time())
                )
        if unfinished_requests:
            data_to_dump.extend(unfinished_requests)

        if not data_to_dump:
            return

        filename = os.path.join(
            self.crash_dump_folder,
            os.getenv("HOSTNAME", "unknown"),  # fall back so os.path.join does not fail when HOSTNAME is unset
            f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
        )

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        # Include server_args in the dump
        data_to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_to_dump,
        }
        with open(filename, "wb") as f:
            pickle.dump(data_to_dump_with_server_args, f)
        logger.error(
            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )
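    # Postmortem sketch: a dump written above can be loaded with plain pickle
    # (the exact path comes from the error log; the name below is hypothetical):
    #
    #     with open("/dumps/host1/crash_dump_2025-01-01_00-00-00.pkl", "rb") as f:
    #         dump = pickle.load(f)
    #     server_args, requests = dump["server_args"], dump["requests"]
    #     # each entry of `requests` is (request_obj, out_dict, created_time, finished_time)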

    async def sigterm_watchdog(self):
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        # Drain requests
        while True:
            remain_num_req = len(self.rid_to_state)

            if self.health_check_failed:
                # if health check failed, we should exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                self.dump_requests_before_crash()
                break
            elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                # if force shutdown flag set, exit immediately
                logger.error(
                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )
            if remain_num_req > 0:
                await asyncio.sleep(5)
            else:
                self.dump_requests_before_crash()
                break

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)

    async def handle_loop(self):
        """The event loop that handles requests"""

        while True:
            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
            self._result_dispatcher(recv_obj)
            self.last_receive_tstamp = time.perf_counter()

    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        for i, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # Build meta_info and return value
            meta_info = {
                "id": rid,
                "finish_reason": recv_obj.finished_reasons[i],
                "prompt_tokens": recv_obj.prompt_tokens[i],
            }

            if getattr(state.obj, "return_logprob", False):
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init,
                    recv_obj,
                    i,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info.update(
                    {
                        "completion_tokens": recv_obj.completion_tokens[i],
                        "cached_tokens": recv_obj.cached_tokens[i],
                    }
                )

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

            if isinstance(recv_obj, BatchStrOut):
                state.text += recv_obj.output_strs[i]
                out_dict = {
                    "text": state.text,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                if self.server_args.stream_output and state.obj.stream:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids[state.last_output_offset :]
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
                    output_token_ids = state.output_ids.copy()

                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[i] is not None
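            # Finished requests are removed from rid_to_state below, so any late
            # output for the same rid hits the "state was deleted" branch above.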
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            state.out_list.append(out_dict)
            state.event.set()

            # Log metrics and dump
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, i)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        if recv_obj.input_token_logprobs_val is None:
            return

        if len(recv_obj.input_token_logprobs_val) > 0:
            state.input_token_logprobs_val.extend(
                recv_obj.input_token_logprobs_val[recv_obj_index]
            )
            state.input_token_logprobs_idx.extend(
                recv_obj.input_token_logprobs_idx[recv_obj_index]
            )
        state.output_token_logprobs_val.extend(
            recv_obj.output_token_logprobs_val[recv_obj_index]
        )
        state.output_token_logprobs_idx.extend(
            recv_obj.output_token_logprobs_idx[recv_obj_index]
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val,
            state.input_token_logprobs_idx,
            return_text_in_logprobs,
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            return_text_in_logprobs,
        )

        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val) > 0:
                state.input_top_logprobs_val.extend(
                    recv_obj.input_top_logprobs_val[recv_obj_index]
                )
                state.input_top_logprobs_idx.extend(
                    recv_obj.input_top_logprobs_idx[recv_obj_index]
                )
            state.output_top_logprobs_val.extend(
                recv_obj.output_top_logprobs_val[recv_obj_index]
            )
            state.output_top_logprobs_idx.extend(
                recv_obj.output_top_logprobs_idx[recv_obj_index]
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val,
                state.input_top_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                return_text_in_logprobs,
            )

        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val) > 0:
                state.input_token_ids_logprobs_val.extend(
                    recv_obj.input_token_ids_logprobs_val[recv_obj_index]
                )
                state.input_token_ids_logprobs_idx.extend(
                    recv_obj.input_token_ids_logprobs_idx[recv_obj_index]
                )
            state.output_token_ids_logprobs_val.extend(
                recv_obj.output_token_ids_logprobs_val[recv_obj_index]
            )
            state.output_token_ids_logprobs_idx.extend(
                recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                return_text_in_logprobs,
            )
            meta_info["output_token_ids_logprobs"] = (
                self.detokenize_top_logprobs_tokens(
                    state.output_token_ids_logprobs_val,
                    state.output_token_ids_logprobs_idx,
                    return_text_in_logprobs,
                )
            )

    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
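        # Each entry is a (logprob, token_id, token_text) tuple; token_text is
        # None when decode_to_text is False.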
        if not decode_to_text:
            return [
                (logprob, token_id, None)
                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
            ]
        else:
            assert self.tokenizer is not None
            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        ret = []
        for i in range(len(token_logprobs_val)):
            if token_logprobs_val[i]:
                ret.append(
                    self.detokenize_logprob_tokens(
                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
            else:
                ret.append(None)
        return ret

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        completion_tokens = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )
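        # The first output of a request records time-to-first-token; later
        # outputs record inter-token latency for the newly generated tokens.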

        if (
            state.first_token_time == 0.0
            and self.disaggregation_mode != DisaggregationMode.PREFILL
        ):
            state.first_token_time = state.last_time = time.time()
            state.last_completion_tokens = completion_tokens
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            num_new_tokens = completion_tokens - state.last_completion_tokens
            if num_new_tokens:
                new_time = time.time()
                interval = new_time - state.last_time
                self.metrics_collector.observe_inter_token_latency(
                    interval,
                    num_new_tokens,
                )
                state.last_time = new_time
                state.last_completion_tokens = completion_tokens

        if state.finished:
            has_grammar = (
                state.obj.sampling_params.get("json_schema", None)
                or state.obj.sampling_params.get("regex", None)
                or state.obj.sampling_params.get("ebnf", None)
                or state.obj.sampling_params.get("structural_tag", None)
            )
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                completion_tokens,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
                has_grammar,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        self.dump_request_list.append(
            (state.obj, out_dict, state.created_time, time.time())
        )

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            filename = os.path.join(
                self.dump_requests_folder,
                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
            )
            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")

            to_dump = self.dump_request_list
            self.dump_request_list = []

            to_dump_with_server_args = {
                "server_args": self.server_args,
                "requests": to_dump,
            }

            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
                with open(filename, "wb") as f:
                    pickle.dump(to_dump_with_server_args, f)

            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))

    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        current_time = time.time()
        self.crash_dump_request_list.append(
            (state.obj, out_dict, state.created_time, current_time)
        )
        # Remove requests older than 5 minutes based on finish time
        while (
            self.crash_dump_request_list
            and current_time - self.crash_dump_request_list[0][3] >= 300
        ):
            self.crash_dump_request_list.popleft()

    def _handle_abort_req(self, recv_obj):
        state = self.rid_to_state[recv_obj.rid]
        state.finished = True
        state.out_list.append(
            {
                "text": "",
                "meta_info": {
                    "id": recv_obj.rid,
                    "finish_reason": {
                        "type": "abort",
                        "message": "Abort before prefill",
                    },
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                },
            }
        )
        state.event.set()

    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
            # set the future once all results are received
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
        See Engine.score() for more details.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        if self.tokenizer is not None:
            vocab_size = self.tokenizer.vocab_size
            for token_id in label_token_ids:
                if token_id >= vocab_size:
                    raise ValueError(
                        f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                    )

        # Handle string or tokenized query/items
        if isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
            if item_first:
                prompts = [f"{item}{query}" for item in items_list]
            else:
                prompts = [f"{query}{item}" for item in items_list]
            batch_request = GenerateReqInput(
                text=prompts,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
            if item_first:
                input_ids_list = [item + query for item in items]
            else:
                input_ids_list = [query + item for item in items]
            batch_request = GenerateReqInput(
                input_ids=input_ids_list,
                return_logprob=True,
                token_ids_logprob=label_token_ids,
                stream=False,
                sampling_params={"max_new_tokens": 1},
            )
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()
        scores = []

        for result in results:
            # Get logprobs for each token
            logprobs = {}
            for logprob, token_id, _ in result["meta_info"].get(
                "output_token_ids_logprobs", []
            )[0]:
                if token_id in label_token_ids:
                    logprobs[token_id] = logprob

            # Get scores in order of label_token_ids
            score_list = [
                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
            ]

            # Apply softmax to logprobs if needed
            if apply_softmax:
                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
            else:
                # Convert logprobs to probabilities if not using softmax
                score_list = [
                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
                ]

            scores.append(score_list)

        return scores
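    # Illustrative sketch (prompt and token ids are hypothetical): score two
    # answer candidates by the probability of their label tokens.
    #
    #     scores = await tokenizer_manager.score_request(
    #         query="Is the sky blue? Answer:",
    #         items=[" Yes", " No"],
    #         label_token_ids=[9454, 2360],  # hypothetical ids for " Yes" / " No"
    #         apply_softmax=True,
    #     )
    #     # -> one inner list per item, each with one score per label token id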


async def print_exception_wrapper(func):
    """
    Sometimes an asyncio task does not print its exception.
    This wrapper logs the exception, dumps pending requests, and exits the process.
    """
    try:
        await func()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {traceback}")
        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
            func.__self__.dump_requests_before_crash()
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


class SignalHandler:
    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        self.tokenizer_manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        self.tokenizer_manager.dump_requests_before_crash()
        kill_process_tree(os.getpid())


T = TypeVar("T")


class _Communicator(Generic[T]):
    """Note: The communicator only runs up to one in-flight request at any time."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[T]] = None
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        if len(self._ready_queue) > 0:
            self._ready_queue.popleft().set()

        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()

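# Usage sketch for _Communicator (names are placeholders, not from this file):
# TokenizerManager awaits a communicator with a request object, and the result
# dispatcher feeds replies back through handle_recv until `fan_out` responses
# have arrived.
#
#     communicator = _Communicator(send_to_scheduler, fan_out=server_args.dp_size)
#     responses = await communicator(GetInternalStateReq())  # list of fan_out replies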

# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | waiting queue   | background task | fast api              | del in _handle_abort_req    |
# | http       | yes          | running         | background task | fast api              | del in _handle_batch_output |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#