tokenizer_manager.py 99.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
"""TokenizerManager is a process that tokenizes the text."""
15

Lianmin Zheng's avatar
Lianmin Zheng committed
16
import asyncio
17
18
import copy
import dataclasses
19
import json
20
import logging
21
import math
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import os
23
import pickle
24
25
import signal
import sys
26
import threading
27
import time
28
import uuid
29
from collections import deque
fzyzcjy's avatar
fzyzcjy committed
30
from contextlib import nullcontext
31
from datetime import datetime
32
from enum import Enum
33
from http import HTTPStatus
34
35
36
37
38
39
40
41
42
43
44
45
from typing import (
    Any,
    Awaitable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
46

47
import fastapi
48
import torch
Lianmin Zheng's avatar
Lianmin Zheng committed
49
50
51
import uvloop
import zmq
import zmq.asyncio
52
from fastapi import BackgroundTasks
Liangsheng Yin's avatar
Liangsheng Yin committed
53

54
from sglang.srt.aio_rwlock import RWLock
55
from sglang.srt.configs.model_config import ModelConfig
56
57
58
59
60
61
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)
xm:D's avatar
xm:D committed
62
63
64
65
66
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
67
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
Lianmin Zheng's avatar
Lianmin Zheng committed
68
from sglang.srt.managers.io_struct import (
69
    AbortReq,
70
    BatchEmbeddingOut,
71
    BatchMultimodalOut,
Lianmin Zheng's avatar
Lianmin Zheng committed
72
    BatchStrOut,
73
    BatchTokenIDOut,
74
    CloseSessionReqInput,
75
    ConfigureLoggingReq,
76
    EmbeddingReqInput,
77
    ExpertDistributionReq,
78
    ExpertDistributionReqOutput,
79
80
    FlushCacheReqInput,
    FlushCacheReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
81
    GenerateReqInput,
82
83
    GetInternalStateReq,
    GetInternalStateReqOutput,
84
85
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
86
    HealthCheckOutput,
87
88
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
89
90
91
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
92
    MultiTokenizerRegisterReq,
93
94
    OpenSessionReqInput,
    OpenSessionReqOutput,
95
    ProfileReq,
96
97
    ProfileReqOutput,
    ProfileReqType,
98
99
100
101
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
102
    SessionParams,
103
104
    SetInternalStateReq,
    SetInternalStateReqOutput,
105
106
    SlowDownReqInput,
    SlowDownReqOutput,
107
108
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
109
110
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
111
112
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
113
114
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
115
116
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
Lianmin Zheng's avatar
Lianmin Zheng committed
117
)
Mick's avatar
Mick committed
118
from sglang.srt.managers.mm_utils import TensorTransportMode
119
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
120
from sglang.srt.managers.scheduler import is_health_check_generate_req
fzyzcjy's avatar
fzyzcjy committed
121
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
122
123
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
Lianmin Zheng's avatar
Lianmin Zheng committed
124
from sglang.srt.server_args import PortArgs, ServerArgs
125
126
from sglang.srt.utils import (
    dataclass_to_string_truncated,
127
    get_bool_env_var,
128
129
    get_origin_rid,
    get_workerids_from_rids,
130
131
132
    get_zmq_socket,
    kill_process_tree,
)
133
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
Lianmin Zheng's avatar
Lianmin Zheng committed
134
135
136

# Install uvloop's event-loop policy process-wide as a side effect of importing
# this module; loops created afterwards through asyncio's policy (e.g.
# asyncio.new_event_loop()) are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

137
138
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

Lianmin Zheng's avatar
Lianmin Zheng committed
139

140
141
142
143
@dataclasses.dataclass
class ReqState:
    """Store the state of a request.

    TokenizerManager keeps one instance per request id (see ``rid_to_state``).
    Field order matters: instances may be constructed positionally, so new
    fields must be appended with defaults.
    """

    # Outputs accumulated for the request, one dict per produced chunk.
    out_list: List[Dict[Any, Any]]
    # Whether the request has completed.
    finished: bool
    # NOTE(review): presumably set whenever new output is appended to
    # ``out_list`` so waiters can wake up; confirm against the output handler.
    event: asyncio.Event
    # The original (un-tokenized) request object.
    obj: Union[GenerateReqInput, EmbeddingReqInput]

    # For metrics
    created_time: float
    finished_time: float = 0.0
    first_token_time: float = 0.0
    last_time: float = 0.0
    last_completion_tokens: int = 1

    # For streaming output
    last_output_offset: int = 0

    # For incremental state update.
    # TODO(lianmin): do not initialize some lists if not needed.
    text: str = ""
    output_ids: List[int] = dataclasses.field(default_factory=list)
    input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    output_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
    output_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
    input_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    input_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    output_top_logprobs_val: List[List[float]] = dataclasses.field(default_factory=list)
    output_top_logprobs_idx: List[List[int]] = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    input_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_val: List = dataclasses.field(default_factory=list)
    output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
175
176


177
178
179
# Number of tokenizer worker processes. Defaults to 1 and is overwritten with
# ``server_args.tokenizer_worker_num`` every time a TokenizerManager is constructed.
_global_tokenizer_worker_num = 1


180
181
class TokenizerManager:
    """TokenizerManager is a process that tokenizes the text."""
182

Lianmin Zheng's avatar
Lianmin Zheng committed
183
184
185
186
    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        is_main: Optional[bool] = True,
    ):
        """Set up tokenizer/processor, IPC sockets, request bookkeeping and communicators.

        Args:
            server_args: Parsed server configuration.
            port_args: IPC endpoint names used to build the ZMQ sockets.
            is_main: Only meaningful when ``server_args.tokenizer_worker_num > 1``.
                The main manager owns the scheduler socket and forwards requests
                received from the other tokenizer workers. Non-main managers push
                their requests to the main manager instead.

        Raises:
            RuntimeError: If creating the Ascend mf config store fails.
        """
        # Parse args
        self.server_args = server_args
        self.enable_metrics = server_args.enable_metrics
        self.log_requests = server_args.log_requests
        self.log_requests_level = server_args.log_requests_level
        self.preferred_sampling_params = (
            json.loads(server_args.preferred_sampling_params)
            if server_args.preferred_sampling_params
            else None
        )
        self.crash_dump_folder = server_args.crash_dump_folder

        self.is_main = is_main
        self.worker_id = os.getpid()

        # Read model args
        self.model_path = server_args.model_path
        self.served_model_name = server_args.served_model_name
        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        self.max_req_input_len = None  # Will be set later in engine.py

        if self.model_config.is_multimodal:
            import_processors()
            try:
                _processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
            except ValueError as e:
                error_message = str(e)
                if "does not have a slow version" in error_message:
                    logger.info(
                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
                    )
                    _processor = get_processor(
                        server_args.tokenizer_path,
                        tokenizer_mode=server_args.tokenizer_mode,
                        trust_remote_code=server_args.trust_remote_code,
                        revision=server_args.revision,
                        use_fast=True,
                    )
                else:
                    raise
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # The multimodal processor is created regardless of skip_tokenizer_init,
            # so multimodal inputs are still processed even when
            # skip_tokenizer_init=True.
            self.mm_processor = get_mm_processor(
                self.model_config.hf_config, server_args, _processor, transport_mode
            )

            if server_args.skip_tokenizer_init:
                self.tokenizer = self.processor = None
            else:
                self.processor = _processor
                self.tokenizer = get_tokenizer_from_processor(self.processor)
                os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

        # Init inter-process communication
        context = zmq.asyncio.Context(3)
        self.recv_from_detokenizer = get_zmq_socket(
            context, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        global _global_tokenizer_worker_num
        _global_tokenizer_worker_num = server_args.tokenizer_worker_num
        if server_args.tokenizer_worker_num > 1:
            self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
            if self.is_main:
                self.send_to_scheduler = get_zmq_socket(
                    context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
                )
                self.receive_from_worker = get_zmq_socket(
                    context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
                )
                # The background event loop that forwards worker requests and runs
                # handle_loop is started at the very end of __init__ (see below).
            else:
                # Non-main workers push to the main manager's receive_from_worker
                # socket instead of talking to the scheduler directly.
                self.send_to_scheduler = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
                )
        else:
            self.send_to_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
            )

        # Request states
        self.no_create_loop = False
        self.rid_to_state: Dict[str, ReqState] = {}
        self.asyncio_tasks = set()

        # Health check
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
        self.server_status = ServerStatus.Starting

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
        self.dump_requests_threshold = 1000
        self.dump_request_list: List[Tuple] = []
        self.log_request_metadata = self.get_log_request_metadata()
        self.crash_dump_request_list: deque[Tuple] = deque()
        self.crash_dump_performed = False  # Flag to ensure dump is only called once

        # Session
        self.session_futures = {}  # session_id -> asyncio event

        # Weight updates
        # The event to notify the weight sync is finished.
        self.model_update_lock = RWLock()
        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
            None
        )
        self._is_updating = False
        self._is_updating_cond = asyncio.Condition()

        # LoRA
        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
        # serves as the source of truth for available adapters and maps user-friendly LoRA names
        # to internally used unique LoRA IDs.
        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
        # Lock to serialize LoRA update operations.
        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
        # LoRA updates and inference to overlap.
        self.lora_update_lock = asyncio.Lock()

        # For PD disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Start kv bootstrap server on prefill
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # only start bootstrap server on prefill tm
            if self.is_main:
                kv_bootstrap_server_class = get_kv_class(
                    self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
                )
                self.bootstrap_server = kv_bootstrap_server_class(
                    self.server_args.disaggregation_bootstrap_port
                )
                is_create_store = (
                    self.server_args.node_rank == 0
                    and self.server_args.disaggregation_transfer_backend == "ascend"
                )
                if is_create_store:
                    try:
                        from mf_adapter import create_config_store

                        ascend_url = os.getenv("ASCEND_MF_STORE_URL")
                        create_config_store(ascend_url)
                    except Exception as e:
                        # Raising a plain str is itself a TypeError in Python 3 and
                        # would mask the real failure; raise a proper exception
                        # chained to the original cause instead.
                        raise RuntimeError(
                            "Failed create mf store, invalid ascend_url. "
                            f"With exception {e}"
                        ) from e

        # For load balancing
        self.current_load = 0
        self.current_load_lock = asyncio.Lock()

        # Metrics
        if self.enable_metrics:
            self.metrics_collector = TokenizerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    # TODO: Add lora name/path in the future,
                },
                bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
                collect_tokens_histogram=self.server_args.collect_tokens_histogram,
            )

        # Communicators
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.release_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.slow_down_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.flush_cache_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.set_internal_state_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.expert_distribution_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_lora_adapter_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

        self._result_dispatcher = TypeBasedDispatcher(
            [
                (
                    (
                        BatchStrOut,
                        BatchEmbeddingOut,
                        BatchTokenIDOut,
                        BatchMultimodalOut,
                    ),
                    self._handle_batch_output,
                ),
                (AbortReq, self._handle_abort_req),
                (OpenSessionReqOutput, self._handle_open_session_req_output),
                (
                    UpdateWeightFromDiskReqOutput,
                    self._handle_update_weights_from_disk_req_output,
                ),
                (
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromTensorReqOutput,
                    self.update_weights_from_tensor_communicator.handle_recv,
                ),
                (
                    GetWeightsByNameReqOutput,
                    self.get_weights_by_name_communicator.handle_recv,
                ),
                (
                    ReleaseMemoryOccupationReqOutput,
                    self.release_memory_occupation_communicator.handle_recv,
                ),
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
                (
                    SlowDownReqOutput,
                    self.slow_down_communicator.handle_recv,
                ),
                (
                    FlushCacheReqOutput,
                    self.flush_cache_communicator.handle_recv,
                ),
                (
                    ProfileReqOutput,
                    self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
                    self.get_internal_state_communicator.handle_recv,
                ),
                (
                    SetInternalStateReqOutput,
                    self.set_internal_state_communicator.handle_recv,
                ),
                (
                    ExpertDistributionReqOutput,
                    self.expert_distribution_communicator.handle_recv,
                ),
                (
                    LoRAUpdateResult,
                    self.update_lora_adapter_communicator.handle_recv,
                ),
                (HealthCheckOutput, lambda x: None),
            ]
        )

        # Multi-tokenizer mode, main manager: start the background event loop only
        # now, after every attribute above exists. The coroutines scheduled on it
        # run on a separate thread and would otherwise race with the remainder of
        # __init__ (e.g. observe a manager without its dispatcher or state maps).
        if server_args.tokenizer_worker_num > 1 and self.is_main:
            self._loop = asyncio.new_event_loop()
            self._thread = threading.Thread(target=self._run_loop, daemon=True)
            self._thread.start()
            self._task = asyncio.run_coroutine_threadsafe(
                self.router_worker_obj(), self._loop
            )
            # Start handle_loop simultaneously
            self._handle_task = asyncio.run_coroutine_threadsafe(
                print_exception_wrapper(self.handle_loop), self._loop
            )

509
510
511
512
513
514
515
516
    def _run_loop(self):
        self._loop.run_forever()

    async def router_worker_obj(self):
        """Relay every object received from tokenizer workers to the scheduler.

        Never returns on its own. It is scheduled on the main manager's
        background event loop.
        """
        while True:
            forwarded = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(forwarded)

517
    async def generate_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
    ):
        """Tokenize ``obj``, send it to the scheduler and stream back its outputs.

        A single request is tokenized and sent here and then awaited via
        ``_wait_one_response``. A batched request is delegated wholesale to
        ``_handle_batch_request``. Items from either path are yielded unchanged.

        Args:
            obj: A generate or embedding request (single or batched).
            request: The originating FastAPI request, forwarded to the response
                helpers; may be ``None``.
        """
        created_time = time.time()
        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

        # Wait until ``_is_updating`` is cleared before accepting the request.
        async with self._is_updating_cond:
            await self._is_updating_cond.wait_for(lambda: not self._is_updating)

        if self.server_args.tokenizer_worker_num > 1:
            # Prefix every rid with this process's pid. Presumably this lets
            # responses be routed back to the owning tokenizer worker.
            worker_prefix = f"{self.worker_id}_"
            if isinstance(obj.rid, list):
                obj.rid = [f"{worker_prefix}{rid}" for rid in obj.rid]
            else:
                obj.rid = f"{worker_prefix}{obj.rid}"

        if self.log_requests:
            max_length, skip_names, _ = self.log_request_metadata
            obj_repr = dataclass_to_string_truncated(
                obj, max_length, skip_names=skip_names
            )
            logger.info(f"Receive: obj={obj_repr}")

        async with self.model_update_lock.reader_lock:
            if obj.is_single:
                tokenized_obj = await self._tokenize_one_request(obj)
                state = self._send_one_request(obj, tokenized_obj, created_time)
                responses = self._wait_one_response(obj, state, request)
            else:
                responses = self._handle_batch_request(obj, request, created_time)
            async for response in responses:
                yield response

    async def _tokenize_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
    ):
        """Tokenize one request and build the tokenized object for the scheduler.

        The input ids are taken from, in priority order: ``obj.input_embeds`` (with
        ``obj.input_ids`` alongside), ``obj.input_ids``, or the tokenizer applied to
        ``obj.text``. Multimodal inputs, when present, are processed and may replace
        the input ids.

        Raises:
            ValueError: If ``input_embeds`` is used without ``--disable-radix-cache``,
                if a text prompt arrives while the tokenizer is disabled, or if
                ``_validate_one_request`` rejects the request.
        """
        # Tokenize
        input_embeds = None
        input_text = obj.text
        token_type_ids = None
        is_cross_encoder_request = (
            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
        )
        if obj.input_embeds is not None:
            if not self.server_args.disable_radix_cache:
                raise ValueError(
                    "input_embeds is provided while disable_radix_cache is False. "
                    "Please add `--disable-radix-cache` when you launch the server "
                    "if you want to use input_embeds as inputs."
                )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
        elif obj.input_ids is not None:
            input_ids = obj.input_ids
        else:
            if self.tokenizer is None:
                raise ValueError(
                    "The engine initialized with skip_tokenizer_init=True cannot "
                    "accept text prompts. Please provide input_ids or re-initialize "
                    "the engine with skip_tokenizer_init=False."
                )
            encoded = self.tokenizer(
                input_text, return_token_type_ids=is_cross_encoder_request
            )

            input_ids = encoded["input_ids"]
            if is_cross_encoder_request:
                # NOTE(review): the cross-encoder prompt appears to be tokenized as
                # a batch of one (text pair), hence the [0] unwrapping; confirm.
                input_ids = input_ids[0]
                token_type_ids = encoded.get("token_type_ids", [None])[0]

        if self.mm_processor and obj.contains_mm_input():
            if not isinstance(obj.image_data, list):
                obj.image_data = [obj.image_data]
            if not isinstance(obj.audio_data, list):
                obj.audio_data = [obj.audio_data]
            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                image_data=obj.image_data,
                audio_data=obj.audio_data,
                input_text=input_text or input_ids,
                request_obj=obj,
                max_req_input_len=self.max_req_input_len,
            )
            if mm_inputs and "input_ids" in mm_inputs:
                input_ids = mm_inputs["input_ids"]
        else:
            mm_inputs = None

        # Validate BEFORE acquiring a LoRA adapter. The registry tracks ongoing
        # requests per adapter, so acquiring first and then failing validation
        # (e.g. context overflow) would leave an acquisition that is never released.
        self._validate_one_request(obj, input_ids)

        if self.server_args.enable_lora and obj.lora_path:
            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

        return self._create_tokenized_object(
            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
        )

622
    def _validate_one_request(
623
624
625
        self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
    ) -> None:
        """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
626
627

        input_token_num = len(input_ids) if input_ids is not None else 0
628
        # Check if input alone exceeds context length
629
630
631
632
633
634
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

635
636
637
638
639
640
        if isinstance(obj, EmbeddingReqInput) and self.is_generation:
            raise ValueError(
                "This model does not appear to be an embedding model by default. "
                "Please add `--is-embedding` when launching the server or try another model."
            )

641
642
        # Check total tokens (input + max_new_tokens)
        max_new_tokens = obj.sampling_params.get("max_new_tokens")
643
        if (
644
645
            max_new_tokens is not None
            and (max_new_tokens + input_token_num) >= self.context_len
646
        ):
647
648
            total_tokens = max_new_tokens + input_token_num
            error_msg = (
649
                f"Requested token count exceeds the model's maximum context length "
650
                f"of {self.context_len} tokens. You requested a total of {total_tokens} "
651
                f"tokens: {input_token_num} tokens from the input messages and "
652
653
654
655
656
                f"{max_new_tokens} tokens for the completion. Please reduce the number "
                f"of tokens in the input messages or the completion to fit within the limit."
            )
            raise ValueError(error_msg)

657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
        if isinstance(obj, GenerateReqInput):
            if (
                obj.return_hidden_states
                and not self.server_args.enable_return_hidden_states
            ):
                raise ValueError(
                    "The server is not configured to return the hidden states. "
                    "Please set `--enable-return-hidden-states` to enable this feature."
                )
            if (
                obj.custom_logit_processor
                and not self.server_args.enable_custom_logit_processor
            ):
                raise ValueError(
                    "The server is not configured to enable custom logit processor. "
                    "Please set `--enable-custom-logits-processor` to enable this feature."
                )

675
676
677
678
679
680
681
682
    def _validate_input_ids_in_vocab(
        self, input_ids: List[int], vocab_size: int
    ) -> None:
        if any(id >= vocab_size for id in input_ids):
            raise ValueError(
                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
            )

683
684
685
686
687
688
    def _create_tokenized_object(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        input_text: str,
        input_ids: List[int],
        input_embeds: Optional[List[float]] = None,
        mm_inputs: Optional[Dict] = None,
        token_type_ids: Optional[List[int]] = None,
    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
        """Create a tokenized request object from common parameters.

        Sampling parameters are built from ``obj.sampling_params``. Server-wide
        ``preferred_sampling_params`` are layered underneath as defaults, then
        normalized and verified.

        Args:
            obj: The original request.
            input_text: The prompt text (may be ``None`` when ids were supplied).
            input_ids: Token ids of the prompt.
            input_embeds: Optional precomputed input embeddings (generate only).
            mm_inputs: Processed multimodal inputs, if any.
            token_type_ids: Token type ids for cross-encoder requests (embedding only).

        Returns:
            A ``TokenizedGenerateReqInput`` or ``TokenizedEmbeddingReqInput``.

        Raises:
            TypeError: If ``obj`` is neither a generate nor an embedding request.
        """
        # Parse sampling parameters
        # Note: if there are preferred sampling params, we use them if they are not
        # explicitly passed in sampling_params
        if self.preferred_sampling_params:
            sampling_kwargs = {**self.preferred_sampling_params, **obj.sampling_params}
        else:
            sampling_kwargs = obj.sampling_params
        sampling_params = SamplingParams(**sampling_kwargs)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify(self.model_config.vocab_size)

        # Build return object
        if isinstance(obj, GenerateReqInput):
            session_params = (
                SessionParams(**obj.session_params) if obj.session_params else None
            )
            return TokenizedGenerateReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                sampling_params,
                obj.return_logprob,
                obj.logprob_start_len,
                obj.top_logprobs_num,
                obj.token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
                bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_id=obj.lora_id,
                input_embeds=input_embeds,
                session_params=session_params,
                custom_logit_processor=obj.custom_logit_processor,
                return_hidden_states=obj.return_hidden_states,
                data_parallel_rank=obj.data_parallel_rank,
            )
        if isinstance(obj, EmbeddingReqInput):
            return TokenizedEmbeddingReqInput(
                obj.rid,
                input_text,
                input_ids,
                mm_inputs,
                token_type_ids,
                sampling_params,
            )

        # Previously this fell through to `return tokenized_obj` with the name
        # unbound, surfacing as a confusing UnboundLocalError.
        raise TypeError(
            f"Unsupported request type {type(obj).__name__}; expected "
            "GenerateReqInput or EmbeddingReqInput."
        )

743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
    async def _batch_tokenize_and_process(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
        """Handle batch tokenization for text inputs only.

        All sub-request texts are tokenized in a single tokenizer call. Callers
        in this file run ``_validate_batch_tokenization_constraints`` first.

        Args:
            batch_size: Number of sub-requests in ``obj``.
            obj: The batched request.

        Returns:
            One tokenized request object per sub-request, in batch order.
        """
        logger.debug("Starting batch tokenization for %d text requests", batch_size)

        # Materialize the sub-requests once. The same object is then validated
        # and turned into the tokenized request; previously validation ran on a
        # separately indexed ``obj[i]``.
        requests = [obj[i] for i in range(batch_size)]
        texts = [req.text for req in requests]

        # Batch tokenize all texts
        encoded = self.tokenizer(texts)
        input_ids_list = encoded["input_ids"]

        # Process all requests
        tokenized_objs = []
        for i, req in enumerate(requests):
            input_ids = input_ids_list[i]
            self._validate_token_len(req, input_ids)
            tokenized_objs.append(
                self._create_tokenized_object(req, req.text, input_ids, None, None)
            )
        logger.debug("Completed batch processing for %d requests", batch_size)
        return tokenized_objs

    def _validate_batch_tokenization_constraints(
        self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
    ) -> None:
        """Check that every sub-request can go through batched text tokenization.

        Raises:
            ValueError: If a sub-request has multimodal input (checked only for
                generation models), pre-tokenized ``input_ids``, or
                ``input_embeds``.
        """
        for idx in range(batch_size):
            item = obj[idx]
            if self.is_generation and item.contains_mm_input():
                raise ValueError(
                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                )
            if item.input_ids is not None:
                raise ValueError(
                    "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
                )
            if item.input_embeds is not None:
                raise ValueError(
                    "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                )

787
788
789
790
791
792
    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
        created_time: Optional[float] = None,
    ):
        """Send one tokenized request to the scheduler and start tracking it.

        The request is sent first. A fresh ``ReqState`` is then stored in
        ``self.rid_to_state`` under ``obj.rid``.

        Returns:
            The newly registered ``ReqState``.
        """
        self.send_to_scheduler.send_pyobj(tokenized_obj)
        req_state = ReqState(
            [],
            False,
            asyncio.Event(),
            obj,
            created_time=created_time,
        )
        self.rid_to_state[obj.rid] = req_state
        return req_state

    async def _wait_one_response(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        state: ReqState,
        request: Optional[fastapi.Request] = None,
    ):
        """Wait for the response of one request.

        Async generator over the outputs collected in ``state.out_list``.
        Streaming requests yield the latest available output on every wake-up.
        Non-streaming requests yield only the finished output. The generator
        stops after the finished output has been yielded.

        Raises:
            ValueError: If the client disconnected (and ``obj.background`` is not
                set), or the scheduler aborted the request with 400 BAD_REQUEST.
            fastapi.HTTPException: If the scheduler aborted the request with
                503 SERVICE_UNAVAILABLE.
        """

        async def _release_lora() -> None:
            # Give back the LoRA reference held by this request. This also runs on
            # the client-disconnect abort paths; otherwise the reference leaks and
            # ``unload_lora_adapter`` (which calls ``wait_for_unload``) may block.
            if self.server_args.enable_lora and obj.lora_path:
                await self.lora_registry.release(obj.lora_id)

        while True:
            try:
                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, waiting queue)
                    self.abort_request(obj.rid)
                    await _release_lora()
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 1). Abort request {obj.rid=}"
                    )
                continue

            out = state.out_list[-1]

            state.out_list = []
            if state.finished:
                if self.log_requests:
                    max_length, skip_names, out_skip_names = self.log_request_metadata
                    if self.model_config.is_multimodal_gen:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
                    else:
                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                    logger.info(msg)

                # Mark ongoing LoRA request as finished.
                await _release_lora()

                # Check if this was an abort/error created by scheduler
                if isinstance(out["meta_info"].get("finish_reason"), dict):
                    finish_reason = out["meta_info"]["finish_reason"]
                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
                    ):
                        raise ValueError(finish_reason["message"])

                    if (
                        finish_reason.get("type") == "abort"
                        and finish_reason.get("status_code")
                        == HTTPStatus.SERVICE_UNAVAILABLE
                    ):
                        # This is an abort request initiated by scheduler.
                        # Delete the key to prevent resending abort request to the scheduler and
                        # to ensure aborted request state is cleaned up. pop() keeps an
                        # already-removed entry from turning this 503 into a KeyError.
                        self.rid_to_state.pop(state.obj.rid, None)
                        raise fastapi.HTTPException(
                            status_code=finish_reason["status_code"],
                            detail=finish_reason["message"],
                        )
                yield out
                break

            state.event.clear()

            if obj.stream:
                yield out
            else:
                if (
                    request is not None
                    and not obj.background
                    and await request.is_disconnected()
                ):
                    # Abort the request for disconnected requests (non-streaming, running)
                    self.abort_request(obj.rid)
                    await _release_lora()
                    # Use exception to kill the whole call stack and asyncio task
                    raise ValueError(
                        f"Request is disconnected from the client side (type 3). Abort request {obj.rid=}"
                    )

    async def _handle_batch_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        request: Optional[fastapi.Request] = None,
        created_time: Optional[float] = None,
    ):
        """Send every sub-request of a batch and yield their outputs.

        Each sub-request is tokenized, sent to the scheduler and paired with a
        ``_wait_one_response`` generator. When ``parallel_sample_num > 1``, each
        prompt is first sent once with ``max_new_tokens = 0`` to cache the common
        prefix. It is then duplicated ``parallel_sample_num`` times with fresh rids.

        Non-streaming: yields one list holding the next output of every generator,
        in send order. Streaming: yields individual outputs as they arrive, each
        tagged with ``"index"`` (the position of its rid in send order).
        """
        batch_size = obj.batch_size

        generators = []
        rids = []
        if getattr(obj, "parallel_sample_num", 1) == 1:
            if self.server_args.enable_tokenizer_batch_encode:
                # Validate batch tokenization constraints
                self._validate_batch_tokenization_constraints(batch_size, obj)

                tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)

                for i, tokenized_obj in enumerate(tokenized_objs):
                    tmp_obj = obj[i]
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)
            else:
                # Sequential tokenization and processing
                with (
                    input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
                    if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
                    else nullcontext()
                ):
                    for i in range(batch_size):
                        tmp_obj = obj[i]
                        tokenized_obj = await self._tokenize_one_request(tmp_obj)
                        state = self._send_one_request(
                            tmp_obj, tokenized_obj, created_time
                        )
                        generators.append(
                            self._wait_one_response(tmp_obj, state, request)
                        )
                        rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
                logger.warning(
                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
                    "The performance might be better if you just duplicate the requests n times or use "
                    "many threads to send them one by one with parallel sampling (n > 1)."
                )

            # Tokenize all requests. The generator uses its own loop name so it does
            # not shadow the ``obj`` argument, which is still read below
            # (``obj.parallel_sample_num`` and ``obj.stream``).
            objs = [obj[i] for i in range(batch_size)]
            tokenized_objs = await asyncio.gather(
                *(self._tokenize_one_request(sub_obj) for sub_obj in objs)
            )

            # Cache the common prefix for parallel sampling
            for i in range(batch_size):
                tmp_obj = copy.copy(objs[i])
                tokenized_obj = copy.copy(tokenized_objs[i])
                tokenized_obj.rid = tmp_obj.regenerate_rid()
                tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
                tokenized_obj.sampling_params.max_new_tokens = 0
                tokenized_obj.stream = False
                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                await self._wait_one_response(tmp_obj, state, request).__anext__()

            # Expand requests, assign new rids for them, and send them
            for i in range(batch_size):
                for _ in range(obj.parallel_sample_num):
                    tmp_obj = copy.copy(objs[i])
                    tokenized_obj = copy.copy(tokenized_objs[i])
                    tokenized_obj.rid = tmp_obj.regenerate_rid()
                    state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
                    generators.append(self._wait_one_response(tmp_obj, state, request))
                    rids.append(tmp_obj.rid)

        # Wait for all requests
        is_stream = hasattr(obj, "stream") and obj.stream
        if not is_stream:
            outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
            yield outputs
        else:
            rid_to_index = {rid: i for i, rid in enumerate(rids)}
            task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
            while task_map:
                done, _ = await asyncio.wait(
                    task_map.keys(), return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    gen = task_map.pop(task)
                    try:
                        result = task.result()
                        result["index"] = rid_to_index[result["meta_info"]["id"]]
                        yield result
                        new_task = asyncio.create_task(gen.__anext__())
                        task_map[new_task] = gen
                    except StopAsyncIteration:
                        pass

979
    async def flush_cache(self) -> FlushCacheReqOutput:
        """Send a ``FlushCacheReqInput`` and return the first reply received."""
        replies = await self.flush_cache_communicator(FlushCacheReqInput())
        return replies[0]

982
983
    def abort_request(self, rid: str = "", abort_all: bool = False):
        """Ask the scheduler to abort one request, or all of them.

        Args:
            rid: Id of the request to abort. If it is not tracked in
                ``self.rid_to_state`` and ``abort_all`` is False, nothing is sent.
            abort_all: Send an abort-all request regardless of ``rid``.
        """
        if abort_all or rid in self.rid_to_state:
            self.send_to_scheduler.send_pyobj(AbortReq(rid, abort_all))
            if self.enable_metrics:
                self.metrics_collector.observe_one_aborted_request()

991
992
993
    async def start_profile(
        self,
        output_dir: Optional[str] = None,
        start_step: Optional[int] = None,
        num_steps: Optional[int] = None,
        activities: Optional[List[str]] = None,
        with_stack: Optional[bool] = None,
        record_shapes: Optional[bool] = None,
        profile_by_stage: bool = False,
    ):
        """Start a profiling session through ``_execute_profile``.

        Stack recording is turned off if ``with_stack`` is explicitly False, or if
        the ``SGLANG_PROFILE_WITH_STACK`` environment variable (default "true")
        evaluates to False. Otherwise it is on.
        """
        self.auto_create_handle_loop()

        stack_allowed_by_env: bool = get_bool_env_var(
            "SGLANG_PROFILE_WITH_STACK", "true"
        )
        effective_with_stack = with_stack is not False and stack_allowed_by_env is not False

        return await self._execute_profile(
            ProfileReq(
                type=ProfileReqType.START_PROFILE,
                output_dir=output_dir,
                start_step=start_step,
                num_steps=num_steps,
                activities=activities,
                with_stack=effective_with_stack,
                record_shapes=record_shapes,
                profile_by_stage=profile_by_stage,
                profile_id=str(time.time()),
            )
        )

    async def stop_profile(self):
        """Stop the current profiling session through ``_execute_profile``."""
        self.auto_create_handle_loop()
        stop_req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
        return await self._execute_profile(stop_req)

    async def _execute_profile(self, req: ProfileReq):
        """Send ``req`` through the profile communicator and return the first reply.

        Raises:
            RuntimeError: If the reply reports ``success`` as False; the reply's
                message is used as the error text.
        """
        replies = await self.profile_communicator(req)
        first_reply = replies[0]
        if first_reply.success:
            return first_reply
        raise RuntimeError(first_reply.message)

1028
    async def start_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.START_RECORD`` to the backend."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.START_RECORD
        await self.expert_distribution_communicator(command)

1032
    async def stop_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.STOP_RECORD`` to the backend."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.STOP_RECORD
        await self.expert_distribution_communicator(command)

1036
    async def dump_expert_distribution_record(self):
        """Send ``ExpertDistributionReq.DUMP_RECORD`` to the backend."""
        self.auto_create_handle_loop()
        command = ExpertDistributionReq.DUMP_RECORD
        await self.expert_distribution_communicator(command)

1040
    async def pause_generation(self):
        """Set the updating flag and abort every request.

        Both happen while ``self._is_updating_cond`` is held.
        """
        cond = self._is_updating_cond
        async with cond:
            self._is_updating = True
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        """Clear the updating flag and wake all waiters on ``_is_updating_cond``."""
        cond = self._is_updating_cond
        async with cond:
            self._is_updating = False
            cond.notify_all()

Chayenne's avatar
Chayenne committed
1050
1051
1052
1053
    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str, Any]:
        """Reload model weights from disk while holding the model-update writer lock.

        Args:
            obj: Update request. ``obj.load_format`` defaults to the server's
                configured load format when it is None. If
                ``obj.abort_all_requests`` is set, all requests are aborted first.
            request: Unused here.

        Returns:
            The 3-tuple produced by ``_wait_for_model_update_from_disk``:
            ``(success, message, num_paused_requests)``. The annotation was
            previously ``Tuple[bool, str]``, which did not match.
        """
        self.auto_create_handle_loop()

        # default the load format to the server_args
        if obj.load_format is None:
            obj.load_format = self.server_args.load_format
        logger.info("Start update_weights. Load format=%s", obj.load_format)

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
                return await self._wait_for_model_update_from_disk(obj)

1071
1072
    async def _wait_for_model_update_from_disk(
        self, obj: UpdateWeightFromDiskReqInput
    ) -> Tuple[bool, str, Any]:
        """Send a disk weight-update request and wait for the scheduler result(s).

        On success, ``served_model_name``, ``model_path`` and the server args'
        ``model_path`` / ``load_format`` are switched to the values in ``obj``.
        This is done for both the single- and multi-DP paths; previously the
        multi-DP path did not update ``served_model_name``.

        Returns:
            With ``dp_size == 1``: ``(success, message, num_paused_requests)``.
            Otherwise: ``(all ranks succeeded, " | "-joined messages,
            list of per-rank num_paused_requests)``.
        """
        # Set up the result containers before sending, so they already exist
        # when any reply is processed.
        self.model_update_result = asyncio.Future()
        if self.server_args.dp_size != 1:
            self.model_update_tmp = []
        self.send_to_scheduler.send_pyobj(obj)

        if self.server_args.dp_size == 1:
            result = await self.model_update_result
            if result.success:
                self.served_model_name = obj.model_path
                self.server_args.model_path = obj.model_path
                self.server_args.load_format = obj.load_format
                self.model_path = obj.model_path
            return result.success, result.message, result.num_paused_requests

        # self.server_args.dp_size > 1
        result = await self.model_update_result

        all_success = all(r.success for r in result)
        if all_success:
            self.served_model_name = obj.model_path
            self.server_args.model_path = obj.model_path
            self.server_args.load_format = obj.load_format
            self.model_path = obj.model_path
        all_message = " | ".join(r.message for r in result)
        all_paused_requests = [r.num_paused_requests for r in result]
        return all_success, all_message, all_paused_requests

1098
1099
1100
1101
    async def init_weights_update_group(
        self,
        obj: InitWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Initialize the parameter-update group on the backend.

        Only supported with ``dp_size == 1``; this is checked with an assert.

        Returns:
            ``(success, message)`` taken from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for init parameter update group"
        replies = await self.init_weights_update_group_communicator(obj)
        first_reply = replies[0]
        return first_reply.success, first_reply.message

    async def update_weights_from_distributed(
        self,
        obj: UpdateWeightsFromDistributedReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update model weights through the distributed update communicator.

        Requires ``dp_size == 1`` or DP attention (asserted). If
        ``obj.abort_all_requests`` is set, all requests are aborted first. The
        update runs while holding the model-update writer lock.

        Returns:
            ``(success, message)`` taken from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # The writer lock keeps weight sync from running while requests are in progress.
        async with self.model_update_lock.writer_lock:
            replies = await self.update_weights_from_distributed_communicator(obj)
            first_reply = replies[0]
            return first_reply.success, first_reply.message

1129
1130
1131
1132
1133
1134
1135
    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        """Update model weights from tensors carried by ``obj``.

        Requires ``dp_size == 1`` or DP attention (asserted). If
        ``obj.abort_all_requests`` is set, all requests are aborted first. The
        update runs while holding the model-update writer lock.

        Returns:
            ``(success, message)`` taken from the first reply.
        """
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"

        if obj.abort_all_requests:
            self.abort_request(abort_all=True)

        # The writer lock keeps weight sync from running while requests are in progress.
        async with self.model_update_lock.writer_lock:
            replies = await self.update_weights_from_tensor_communicator(obj)
            first_reply = replies[0]
            return first_reply.success, first_reply.message

1148
1149
1150
1151
1152
1153
1154
    async def load_lora_adapter(
        self,
        obj: LoadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        """Load a LoRA adapter in the backend, then register it on success.

        Any ``ValueError`` raised inside the body, including LoRA being disabled
        or the ``max_loaded_loras`` limit being reached, is returned as an
        unsuccessful ``LoadLoRAAdapterReqOutput``. The ``dp_size`` assertion is
        not caught.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start load Lora adapter. Lora name=%s, path=%s",
                obj.lora_name,
                obj.lora_path,
            )

            async with self.lora_update_lock:
                limit = self.server_args.max_loaded_loras
                at_capacity = (
                    limit is not None
                    and self.lora_registry.num_registered_loras >= limit
                )
                if at_capacity:
                    raise ValueError(
                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
                        "Please unload some LoRA adapters before loading new ones."
                    )

                # Fresh, uniquely identifiable reference for this adapter.
                adapter_ref = LoRARef(
                    lora_name=obj.lora_name,
                    lora_path=obj.lora_path,
                    pinned=obj.pinned,
                )

                # Ask the backend processes to actually load it.
                obj.lora_id = adapter_ref.lora_id
                replies = await self.update_lora_adapter_communicator(obj)
                outcome = replies[0]

                # Register only after a successful load.
                if outcome.success:
                    await self.lora_registry.register(adapter_ref)

                return outcome
        except ValueError as err:
            return LoadLoRAAdapterReqOutput(
                success=False,
                error_message=str(err),
            )

    async def unload_lora_adapter(
        self,
        obj: UnloadLoRAAdapterReqInput,
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        """Unregister a LoRA adapter and unload it from the backend.

        The adapter is unregistered first, so no new request can start using it.
        The backend unload is sent only after ``wait_for_unload`` returns.
        ``ValueError``s, including a missing ``lora_name``, are returned as an
        unsuccessful ``UnloadLoRAAdapterReqOutput``.
        """
        self.auto_create_handle_loop()

        try:
            if not self.server_args.enable_lora:
                raise ValueError(
                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
                )

            # Validate user input with an exception, not ``assert``: asserts are
            # stripped under ``python -O``, and an AssertionError would bypass
            # the ValueError handler below.
            if obj.lora_name is None:
                raise ValueError("lora_name must be provided to unload LoRA adapter")

            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
            # with dp_size > 1.
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for dynamic lora loading"
            logger.info(
                "Start unload Lora adapter. Lora name=%s",
                obj.lora_name,
            )

            async with self.lora_update_lock:
                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
                # from being started.
                lora_id = await self.lora_registry.unregister(obj.lora_name)
                obj.lora_id = lora_id

                # Initiate the actual unloading operation at the backend processes only after all
                # ongoing requests using this LoRA adapter are finished.
                await self.lora_registry.wait_for_unload(lora_id)
                result = (await self.update_lora_adapter_communicator(obj))[0]

                return result
        except ValueError as e:
            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

1248
1249
1250
    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):
        """Fetch a named parameter from the backend.

        Returns:
            The parameter from the single reply when ``dp_size == 1``; otherwise
            a list with one parameter per reply.
        """
        self.auto_create_handle_loop()
        replies = await self.get_weights_by_name_communicator(obj)
        parameters = [reply.parameter for reply in replies]
        return parameters[0] if self.server_args.dp_size == 1 else parameters

1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
    async def release_memory_occupation(
        self,
        obj: ReleaseMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through ``release_memory_occupation_communicator``."""
        self.auto_create_handle_loop()
        communicator = self.release_memory_occupation_communicator
        await communicator(obj)

    async def resume_memory_occupation(
        self,
        obj: ResumeMemoryOccupationReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through ``resume_memory_occupation_communicator``."""
        self.auto_create_handle_loop()
        communicator = self.resume_memory_occupation_communicator
        await communicator(obj)

1275
1276
1277
1278
1279
1280
1281
1282
    async def slow_down(
        self,
        obj: SlowDownReqInput,
        request: Optional[fastapi.Request] = None,
    ):
        """Forward ``obj`` through ``slow_down_communicator`` and await the reply."""
        self.auto_create_handle_loop()
        communicator = self.slow_down_communicator
        await communicator(obj)

1283
1284
1285
    async def open_session(
        self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Open a session on the scheduler and return the id it reports.

        A random hex id is assigned when ``obj.session_id`` is None. Returns None
        if an open request for the same session id is already pending.
        """
        self.auto_create_handle_loop()

        if obj.session_id is None:
            obj.session_id = uuid.uuid4().hex
        elif obj.session_id in self.session_futures:
            return None

        session_key = obj.session_id
        future = asyncio.Future()
        self.session_futures[session_key] = future
        try:
            self.send_to_scheduler.send_pyobj(obj)
            return await future
        finally:
            # Always drop the pending entry, even if sending fails or the wait is
            # cancelled. A leftover entry would make every later open with this
            # id return None through the check above.
            self.session_futures.pop(session_key, None)

    async def close_session(
        self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
    ):
        """Send ``obj`` to the scheduler and await the send."""
        pending_send = self.send_to_scheduler.send_pyobj(obj)
        await pending_send

1305
    async def get_internal_state(self) -> List[Dict[Any, Any]]:
        """Return the ``internal_state`` of every reply (one per DP rank)."""
        replies: List[GetInternalStateReqOutput] = (
            await self.get_internal_state_communicator(GetInternalStateReq())
        )
        # Many DP ranks
        return [reply.internal_state for reply in replies]

1313
1314
1315
1316
1317
1318
1319
1320
    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
        """Apply ``obj`` on the backend and return the per-reply ``updated`` flags.

        One entry per reply from ``set_internal_state_communicator``. The old
        annotation (a single ``SetInternalStateReqOutput``) did not match the
        list being returned.
        """
        responses: List[SetInternalStateReqOutput] = (
            await self.set_internal_state_communicator(obj)
        )
        # NOTE(review): ``internal_state`` is the field of GetInternalStateReqOutput.
        # SetInternalStateReqOutput is assumed to expose ``updated``; confirm
        # against io_struct.
        return [res.updated for res in responses]

1321
1322
1323
1324
1325
1326
1327
1328
    async def get_load(self) -> dict:
        """Return ``{"load": ...}`` with the manager's current load value.

        If no refresh is in progress (``current_load_lock`` is free), the value
        is first refreshed from the ``"load"`` key of the first internal-state
        entry. Otherwise the cached ``self.current_load`` is returned as is.
        """
        # TODO(lsyin): fake load report server
        lock = self.current_load_lock
        if not lock.locked():
            async with lock:
                states = await self.get_internal_state()
                self.current_load = states[0]["load"]
        return {"load": self.current_load}

1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
    def get_log_request_metadata(self):
        """Return ``(max_length, skip_names, out_skip_names)`` for request logging.

        All three are None when ``self.log_requests`` is off. Levels 0 and 1 skip
        bulky input/output fields (level 0 also skips ``sampling_params``).
        Level 2 truncates at 2048 characters, and level 3 effectively does not
        truncate.

        Raises:
            ValueError: If logging is on and ``log_requests_level`` is not 0-3.
        """
        if not self.log_requests:
            return None, None, None

        level = self.log_requests_level
        if level == 0 or level == 1:
            input_skips = {
                "text",
                "input_ids",
                "input_embeds",
                "image_data",
                "audio_data",
                "lora_path",
            }
            if level == 0:
                input_skips.add("sampling_params")
            output_skips = {"text", "output_ids", "embedding"}
            return 1 << 30, input_skips, output_skips
        if level == 2:
            return 2048, None, None
        if level == 3:
            return 1 << 30, None, None
        raise ValueError(
            f"Invalid --log-requests-level: {self.log_requests_level=}"
        )

1383
    def configure_logging(self, obj: ConfigureLoggingReq):
        """Update request-logging settings from ``obj``.

        Each field of ``obj`` that is None leaves the matching setting unchanged.
        ``self.log_request_metadata`` is then recomputed from the new settings.
        """
        if obj.log_requests is not None:
            self.log_requests = obj.log_requests
        if obj.log_requests_level is not None:
            self.log_requests_level = obj.log_requests_level
        if obj.dump_requests_folder is not None:
            self.dump_requests_folder = obj.dump_requests_folder
        if obj.dump_requests_threshold is not None:
            self.dump_requests_threshold = obj.dump_requests_threshold
        if obj.crash_dump_folder is not None:
            self.crash_dump_folder = obj.crash_dump_folder
        # Use the module logger like the rest of this file, not the root logger.
        logger.info("Config logging: obj=%r", obj)
        self.log_request_metadata = self.get_log_request_metadata()

Lianmin Zheng's avatar
Lianmin Zheng committed
1397
    def create_abort_task(self, obj: GenerateReqInput):
        """Return a ``BackgroundTasks`` that aborts ``obj``'s request(s).

        Used when the client has disconnected: the scheduled coroutine waits
        two seconds and then calls ``self.abort_request`` for the single rid,
        or for every rid of a batched request.
        """

        async def _abort_after_grace_period():
            await asyncio.sleep(2)
            target_rids = [obj.rid] if obj.is_single else obj.rid
            for target_rid in target_rids:
                self.abort_request(target_rid)

        tasks = BackgroundTasks()
        tasks.add_task(_abort_after_grace_period)
        return tasks

1411
    def auto_create_handle_loop(self):
        """Start this manager's background asyncio tasks on first call.

        Once ``self.no_create_loop`` is set, further calls return immediately.
        On the first call this schedules ``handle_loop`` and, at the end,
        ``sigterm_watchdog`` (both wrapped by ``print_exception_wrapper``) on
        the loop from ``asyncio.get_event_loop()``, keeps that loop in
        ``self.event_loop``, and installs SIGTERM/SIGQUIT handlers when running
        on the main thread.
        """
        if self.no_create_loop:
            return
        self.no_create_loop = True

        loop = asyncio.get_event_loop()
        handle_task = loop.create_task(print_exception_wrapper(self.handle_loop))
        self.asyncio_tasks.add(handle_task)
        self.event_loop = loop

        on_main_thread = threading.current_thread() is threading.main_thread()
        if not on_main_thread:
            # CPython only allows installing signal handlers from the main thread.
            logger.warning(
                "Signal handler is not added because the tokenizer manager is "
                "not in the main thread. This disables graceful shutdown of the "
                "tokenizer manager when SIGTERM is received."
            )
        else:
            handlers = SignalHandler(self)
            loop.add_signal_handler(signal.SIGTERM, handlers.sigterm_handler)
            # Replaces the SIGQUIT handler that was installed during launch.
            loop.add_signal_handler(
                signal.SIGQUIT, handlers.running_phase_sigquit_handler
            )

        watchdog_task = loop.create_task(
            print_exception_wrapper(self.sigterm_watchdog)
        )
        self.asyncio_tasks.add(watchdog_task)
1441

1442
1443
1444
1445
1446
1447
    def dump_requests_before_crash(self):
        """Pickle recorded and in-flight requests for post-mortem debugging.

        Once a dump has been started (``self.crash_dump_performed`` is set),
        later calls only log and return. Nothing is written when
        ``self.crash_dump_folder`` is unset or there is nothing to dump.

        The pickle holds ``{"server_args": ..., "requests": [...]}``, where
        ``requests`` is the contents of ``self.crash_dump_request_list``
        followed by one ``(obj, last_out_dict_or_{}, created_time, dump_time)``
        tuple per unfinished entry of ``self.rid_to_state``. The file goes to
        ``<crash_dump_folder>/<hostname>/crash_dump_<timestamp>.pkl`` and is
        then uploaded to the ``sglang_crash_dump`` GCS bucket. An upload
        failure is logged rather than raised, so shutdown paths that call this
        method still reach their remaining cleanup.
        """
        import socket

        if self.crash_dump_performed:
            logger.info(
                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
            )
            return

        if not self.crash_dump_folder:
            return

        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
        self.crash_dump_performed = True

        # NOTE: the NFS-availability check is disabled; the dump is always
        # written locally and then uploaded to GCS.
        use_nfs_dir = False
        if not use_nfs_dir:
            logger.error(
                "Expected NFS directory is not available or writable. Uploading to GCS."
            )

        finished_requests = list(self.crash_dump_request_list or [])
        unfinished_requests = [
            (
                state.obj,
                state.out_list[-1] if state.out_list else {},
                state.created_time,
                time.time(),
            )
            for state in self.rid_to_state.values()
            if not state.finished
        ]
        data_to_dump = finished_requests + unfinished_requests
        if not data_to_dump:
            return

        # HOSTNAME is not always exported to the environment; os.path.join and
        # the GCS object-name concatenation both fail on None.
        hostname = os.getenv("HOSTNAME") or socket.gethostname()
        object_name = f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl'
        filename = os.path.join(self.crash_dump_folder, hostname, object_name)
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # Store server_args alongside the request records.
        data_to_dump_with_server_args = {
            "server_args": self.server_args,
            "requests": data_to_dump,
        }
        with open(filename, "wb") as f:
            pickle.dump(data_to_dump_with_server_args, f)
        logger.error(
            f"Dumped {len(finished_requests)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
        )

        def _upload_file_to_gcs(bucket_name, source_file_path, object_name):
            # Imported lazily: google-cloud-storage is only needed for uploads.
            from google.cloud import storage

            client = storage.Client()
            bucket = client.bucket(bucket_name)
            blob = bucket.blob(object_name)
            # if_generation_match=0: only create, never overwrite an object.
            blob.upload_from_filename(source_file_path, if_generation_match=0)
            logger.error(
                f"Successfully uploaded {source_file_path} to gs://{bucket_name}/{object_name}"
            )

        if not use_nfs_dir:
            try:
                _upload_file_to_gcs(
                    "sglang_crash_dump",
                    filename,
                    f"{hostname}/{object_name}",
                )
            except Exception:
                # Best effort: the local dump already exists, and callers on the
                # shutdown path must not be aborted by a failed upload.
                logger.exception(
                    f"Failed to upload crash dump {filename} to GCS; local copy kept."
                )

1525
1526
    async def sigterm_watchdog(self):
        """Wait for a graceful-exit request, drain requests, then kill the process tree.

        Polls every 5 s until ``self.gracefully_exit`` is set. It then re-checks
        every 5 s until ``self.rid_to_state`` is empty, stopping early when
        the health check has failed (after a crash dump) or when
        ``SGL_FORCE_SHUTDOWN`` is set. Finally it kills the process tree,
        including this process, and exits with status 0.
        """
        while not self.gracefully_exit:
            await asyncio.sleep(5)

        while True:
            pending = len(self.rid_to_state)

            if self.health_check_failed:
                # Unhealthy: do not wait for in-flight requests.
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    pending,
                )
                self.dump_requests_before_crash()
                break

            if get_bool_env_var("SGL_FORCE_SHUTDOWN"):
                logger.error(
                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
                    pending,
                )
                break

            logger.info(
                f"Gracefully exiting... remaining number of requests {pending}"
            )
            if pending == 0:
                self.dump_requests_before_crash()
                break
            await asyncio.sleep(5)

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(0)
1561

Lianmin Zheng's avatar
Lianmin Zheng committed
1562
    async def handle_loop(self):
        """Receive outputs from the detokenizer forever and route each one.

        In the main process of a multi-tokenizer deployment
        (``tokenizer_worker_num > 1`` and ``self.is_main``) results are handed to
        ``_distribute_result_to_workers``; otherwise ``_result_dispatcher``
        handles them locally. ``last_receive_tstamp`` is updated after each one.
        """
        while True:
            received = await self.recv_from_detokenizer.recv_pyobj()

            fan_out_to_workers = (
                self.server_args.tokenizer_worker_num > 1 and self.is_main
            )
            if fan_out_to_workers:
                await self._distribute_result_to_workers(received)
            else:
                self._result_dispatcher(received)

            self.last_receive_tstamp = time.time()
Lianmin Zheng's avatar
Lianmin Zheng committed
1574

1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
    def init_tokenizer_mapping(self, recv_obj: MultiTokenizerRegisterReq):
        """Create a ZMQ PUSH socket for every worker named in a register request.

        Worker ids come from ``get_workerids_from_rids(recv_obj.rids)`` and are
        stored as ``int`` keys in ``self.tokenizer_mapping``, all connected to
        ``recv_obj.ipc_name``. Workers that already have a socket are skipped.
        Assumes ``self.tokenizer_mapping`` and ``self._zmq_context`` already
        exist (``_distribute_result_to_workers`` creates them on first use).

        Raises:
            RuntimeError: if ``recv_obj.rids`` is not a list.
        """
        if not isinstance(recv_obj.rids, list):
            raise RuntimeError("tokenizer_worker_num > 1, recv_obj.rids must be list")
        worker_ids = get_workerids_from_rids(recv_obj.rids)

        for worker_id in worker_ids:
            ipc_name = recv_obj.ipc_name
            mapping_key = int(worker_id)

            if mapping_key in self.tokenizer_mapping:
                logger.info(
                    f"ZMQ socket for worker {worker_id} already exists, skipping creation"
                )
                continue

            self.tokenizer_mapping[mapping_key] = get_zmq_socket(
                self._zmq_context, zmq.PUSH, ipc_name, False
            )
            logger.info(
                f"Main Tokenizer Manager Created ZMQ socket for worker {worker_id} with ipc_name {ipc_name}"
            )

    async def _distribute_result_to_workers(self, recv_obj):
        """Forward a result received by the main process to its owning worker(s).

        Worker ids are derived from ``recv_obj.rids``; when none can be derived
        the object is handled locally by ``_result_dispatcher``.

        * ``MultiTokenizerRegisterReq`` for an unknown worker creates that
          worker's socket; for an already-registered worker it is ignored.
        * Batch outputs (``BatchStrOut``, ``BatchEmbeddingOut``,
          ``BatchTokenIDOut``, ``BatchMultimodalOut``) are split so that each
          worker receives a one-row copy for the rid it owns.
        * Any other object is forwarded unchanged, once per rid entry.

        Results for workers without a socket are logged and dropped.

        Raises:
            RuntimeError: if sending to a worker socket fails with ``zmq.ZMQError``.
        """
        worker_ids = get_workerids_from_rids(recv_obj.rids)
        if len(worker_ids) == 0:
            self._result_dispatcher(recv_obj)
            return

        if not hasattr(self, "tokenizer_mapping"):
            self.tokenizer_mapping = {}

        # Create the ZMQ context lazily, on first distribution.
        if not hasattr(self, "_zmq_context"):
            self._zmq_context = zmq.Context()

        is_batch_output = isinstance(
            recv_obj,
            (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, BatchMultimodalOut),
        )

        for i, worker_id in enumerate(worker_ids):
            if worker_id not in self.tokenizer_mapping:
                if isinstance(recv_obj, MultiTokenizerRegisterReq):
                    self.init_tokenizer_mapping(recv_obj)
                else:
                    logger.error(
                        f"Worker {worker_id} not registered and not found in tokenizer mapping . "
                        "Please ensure the worker is registered correctly."
                    )
                continue
            if isinstance(recv_obj, MultiTokenizerRegisterReq):
                # Already registered; nothing to forward.
                continue

            # worker_ids is index-aligned with recv_obj.rids, so row i of a
            # batch output belongs to worker_id.
            payload = (
                self._slice_batch_output(recv_obj, i) if is_batch_output else recv_obj
            )
            try:
                self.tokenizer_mapping[worker_id].send_pyobj(payload)
            except zmq.ZMQError as e:
                raise RuntimeError(
                    f"Failed to send result to worker {worker_id}: {e}"
                ) from e

    @staticmethod
    def _slice_batch_output(recv_obj, i):
        """Return a copy of batch output ``recv_obj`` containing only row ``i``.

        Each per-rid column ``col`` becomes ``[col[i]]``. A ``None`` column
        stays ``None`` and a column with no row ``i`` becomes ``[]``. The
        receiving worker therefore sees the same None-vs-empty distinction that
        ``_handle_batch_output`` / ``convert_logprob_style`` see in
        single-worker mode. Arguments keep the original positional order of
        each output type's constructor.

        Raises:
            TypeError: if ``recv_obj`` is not one of the four batch output types.
        """

        def row(values):
            # One-row view of a per-rid column, preserving None.
            if values is None:
                return None
            return [values[i]] if len(values) > i else []

        if isinstance(recv_obj, BatchTokenIDOut):
            return BatchTokenIDOut(
                [recv_obj.rids[i]],
                row(recv_obj.finished_reasons),
                row(recv_obj.decoded_texts),
                row(recv_obj.decode_ids),
                row(recv_obj.read_offsets),
                row(recv_obj.output_ids),
                row(recv_obj.skip_special_tokens),
                row(recv_obj.spaces_between_special_tokens),
                row(recv_obj.no_stop_trim),
                row(recv_obj.prompt_tokens),
                row(recv_obj.completion_tokens),
                row(recv_obj.cached_tokens),
                row(recv_obj.spec_verify_ct),
                row(recv_obj.input_token_logprobs_val),
                row(recv_obj.input_token_logprobs_idx),
                row(recv_obj.output_token_logprobs_val),
                row(recv_obj.output_token_logprobs_idx),
                row(recv_obj.input_top_logprobs_val),
                row(recv_obj.input_top_logprobs_idx),
                row(recv_obj.output_top_logprobs_val),
                row(recv_obj.output_top_logprobs_idx),
                row(recv_obj.input_token_ids_logprobs_val),
                row(recv_obj.input_token_ids_logprobs_idx),
                row(recv_obj.output_token_ids_logprobs_val),
                row(recv_obj.output_token_ids_logprobs_idx),
                row(recv_obj.output_hidden_states),
            )
        if isinstance(recv_obj, BatchEmbeddingOut):
            return BatchEmbeddingOut(
                [recv_obj.rids[i]],
                row(recv_obj.finished_reasons),
                row(recv_obj.embeddings),
                row(recv_obj.prompt_tokens),
                row(recv_obj.cached_tokens),
            )
        if isinstance(recv_obj, BatchStrOut):
            return BatchStrOut(
                [recv_obj.rids[i]],
                row(recv_obj.finished_reasons),
                row(recv_obj.output_strs),
                row(recv_obj.output_ids),
                row(recv_obj.prompt_tokens),
                row(recv_obj.completion_tokens),
                row(recv_obj.cached_tokens),
                row(recv_obj.spec_verify_ct),
                row(recv_obj.input_token_logprobs_val),
                row(recv_obj.input_token_logprobs_idx),
                row(recv_obj.output_token_logprobs_val),
                row(recv_obj.output_token_logprobs_idx),
                row(recv_obj.input_top_logprobs_val),
                row(recv_obj.input_top_logprobs_idx),
                row(recv_obj.output_top_logprobs_val),
                row(recv_obj.output_top_logprobs_idx),
                row(recv_obj.input_token_ids_logprobs_val),
                row(recv_obj.input_token_ids_logprobs_idx),
                row(recv_obj.output_token_ids_logprobs_val),
                row(recv_obj.output_token_ids_logprobs_idx),
                row(recv_obj.output_hidden_states),
            )
        if isinstance(recv_obj, BatchMultimodalOut):
            return BatchMultimodalOut(
                [recv_obj.rids[i]],
                row(recv_obj.finished_reasons),
                row(recv_obj.outputs),
                row(recv_obj.prompt_tokens),
                row(recv_obj.completion_tokens),
                row(recv_obj.cached_tokens),
            )
        raise TypeError(f"Unsupported batch output type: {type(recv_obj).__name__}")

    def register_to_main_tokenizer_manager(self):
        """Announce this worker's IPC endpoint so the main process can reach it.

        Sends a ``MultiTokenizerRegisterReq`` whose single rid is
        ``"<worker_id>_registertokenizer"`` and whose ``ipc_name`` is this
        worker's ``tokenizer_ipc_name`` through ``send_to_scheduler``, then
        blocks for five seconds.
        """
        register_req = MultiTokenizerRegisterReq()
        register_req.rids = [f"{self.worker_id}_registertokenizer"]
        register_req.ipc_name = self.tokenizer_ipc_name
        self.send_to_scheduler.send_pyobj(register_req)
        # NOTE(review): fixed wait rather than an acknowledgement; presumably
        # gives the main process time to create this worker's socket — confirm.
        time.sleep(5)

1936
    def _handle_batch_output(
        self,
        recv_obj: Union[
            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
        ],
    ):
        """Fan a batched output out to the per-request states it belongs to.

        For every rid with live state this builds ``meta_info`` plus the
        type-specific output dict, appends it to ``state.out_list`` and sets
        ``state.event``. Finished requests get ``e2e_latency`` (and
        ``spec_verify_ct`` when speculative decoding is on) and are removed
        from ``rid_to_state``. Metrics, request dumps and crash-dump recording
        run afterwards when enabled. Rids without state are logged and skipped.

        Raises:
            NotImplementedError: for ``BatchMultimodalOut`` batches.
        """

        def _accumulate_output_ids(state, new_ids, incremental):
            # Append the new ids; when incremental, return only ids not yet
            # emitted (advancing the offset), else a copy of all ids so far.
            state.output_ids.extend(new_ids)
            if not incremental:
                return state.output_ids.copy()
            delta = state.output_ids[state.last_output_offset :]
            state.last_output_offset = len(state.output_ids)
            return delta

        for idx, rid in enumerate(recv_obj.rids):
            state = self.rid_to_state.get(rid, None)
            if state is None:
                logger.error(
                    f"Received output for {rid=} but the state was deleted in TokenizerManager."
                )
                continue

            # With multiple tokenizer workers, rids are rewritten; report the
            # original id (as returned by get_origin_rid) to the client.
            if self.server_args.tokenizer_worker_num > 1:
                origin_rid = get_origin_rid(rid)
            else:
                origin_rid = rid

            meta_info = {
                "id": origin_rid,
                "finish_reason": recv_obj.finished_reasons[idx],
                "prompt_tokens": recv_obj.prompt_tokens[idx],
            }

            if getattr(state.obj, "return_logprob", False):
                text_in_logprobs = (
                    state.obj.return_text_in_logprobs
                    and not self.server_args.skip_tokenizer_init
                )
                self.convert_logprob_style(
                    meta_info,
                    state,
                    state.obj.top_logprobs_num,
                    state.obj.token_ids_logprob,
                    text_in_logprobs,
                    recv_obj,
                    idx,
                )

            if not isinstance(recv_obj, BatchEmbeddingOut):
                meta_info["completion_tokens"] = recv_obj.completion_tokens[idx]
                meta_info["cached_tokens"] = recv_obj.cached_tokens[idx]

            if getattr(recv_obj, "output_hidden_states", None):
                meta_info["hidden_states"] = recv_obj.output_hidden_states[idx]

            if isinstance(recv_obj, BatchStrOut):
                state.text += recv_obj.output_strs[idx]
                output_token_ids = _accumulate_output_ids(
                    state, recv_obj.output_ids[idx], state.obj.stream
                )
                out_dict = {
                    "text": state.text,
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                incremental = self.server_args.stream_output and state.obj.stream
                output_token_ids = _accumulate_output_ids(
                    state, recv_obj.output_ids[idx], incremental
                )
                out_dict = {
                    "output_ids": output_token_ids,
                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchMultimodalOut):
                raise NotImplementedError("BatchMultimodalOut not implemented")
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[idx],
                    "meta_info": meta_info,
                }

            state.finished = recv_obj.finished_reasons[idx] is not None
            if state.finished:
                if self.server_args.speculative_algorithm:
                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[idx]
                state.finished_time = time.time()
                meta_info["e2e_latency"] = state.finished_time - state.created_time
                del self.rid_to_state[rid]

            state.out_list.append(out_dict)
            state.event.set()

            # Metrics for every output; dumps only for finished requests.
            if self.enable_metrics and state.obj.log_metrics:
                self.collect_metrics(state, recv_obj, idx)
            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
                self.dump_requests(state, out_dict)
            if self.crash_dump_folder and state.finished and state.obj.log_metrics:
                self.record_request_for_crash_dump(state, out_dict)
2037
2038
2039
2040

    def convert_logprob_style(
        self,
        meta_info: dict,
        state: ReqState,
        top_logprobs_num: int,
        token_ids_logprob: List[int],
        return_text_in_logprobs: bool,
        recv_obj: BatchStrOut,
        recv_obj_index: int,
    ):
        """Accumulate one request's logprobs from a batch and publish them in ``meta_info``.

        Row ``recv_obj_index`` of each batched ``*_val`` / ``*_idx`` field of
        ``recv_obj`` is appended to the matching running list on ``state``.
        The full accumulated lists are then detokenized into ``meta_info``:

        * ``input_token_logprobs`` / ``output_token_logprobs`` - always;
        * ``input_top_logprobs`` / ``output_top_logprobs`` - only when
          ``top_logprobs_num > 0``;
        * ``input_token_ids_logprobs`` / ``output_token_ids_logprobs`` - only
          when ``token_ids_logprob`` is not None.

        Input-side rows are appended only when the corresponding batched input
        list is non-empty. The method returns immediately, without touching
        ``state`` or ``meta_info``, when ``recv_obj.input_token_logprobs_val``
        is None.
        """
        if recv_obj.input_token_logprobs_val is None:
            return

        row = recv_obj_index
        as_text = return_text_in_logprobs

        def pull(dst_val, dst_idx, src_val, src_idx):
            # Append this request's row of a batched (values, ids) pair onto the
            # request's running (values, ids) lists.
            dst_val.extend(src_val[row])
            dst_idx.extend(src_idx[row])

        # Plain per-token logprobs.
        if len(recv_obj.input_token_logprobs_val):
            pull(
                state.input_token_logprobs_val,
                state.input_token_logprobs_idx,
                recv_obj.input_token_logprobs_val,
                recv_obj.input_token_logprobs_idx,
            )
        pull(
            state.output_token_logprobs_val,
            state.output_token_logprobs_idx,
            recv_obj.output_token_logprobs_val,
            recv_obj.output_token_logprobs_idx,
        )
        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
            state.input_token_logprobs_val, state.input_token_logprobs_idx, as_text
        )
        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
            state.output_token_logprobs_val, state.output_token_logprobs_idx, as_text
        )

        # Top-k logprobs at each position.
        if top_logprobs_num > 0:
            if len(recv_obj.input_top_logprobs_val):
                pull(
                    state.input_top_logprobs_val,
                    state.input_top_logprobs_idx,
                    recv_obj.input_top_logprobs_val,
                    recv_obj.input_top_logprobs_idx,
                )
            pull(
                state.output_top_logprobs_val,
                state.output_top_logprobs_idx,
                recv_obj.output_top_logprobs_val,
                recv_obj.output_top_logprobs_idx,
            )
            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_top_logprobs_val, state.input_top_logprobs_idx, as_text
            )
            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_top_logprobs_val, state.output_top_logprobs_idx, as_text
            )

        # Logprobs of the explicitly requested token ids.
        if token_ids_logprob is not None:
            if len(recv_obj.input_token_ids_logprobs_val):
                pull(
                    state.input_token_ids_logprobs_val,
                    state.input_token_ids_logprobs_idx,
                    recv_obj.input_token_ids_logprobs_val,
                    recv_obj.input_token_ids_logprobs_idx,
                )
            pull(
                state.output_token_ids_logprobs_val,
                state.output_token_ids_logprobs_idx,
                recv_obj.output_token_ids_logprobs_val,
                recv_obj.output_token_ids_logprobs_idx,
            )
            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.input_token_ids_logprobs_val,
                state.input_token_ids_logprobs_idx,
                as_text,
            )
            meta_info["output_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
                state.output_token_ids_logprobs_val,
                state.output_token_ids_logprobs_idx,
                as_text,
            )

2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
    def detokenize_logprob_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Pair logprobs with token ids as ``(logprob, token_id, text)`` tuples.

        ``text`` is None unless ``decode_to_text`` is set. In that case all ids
        are decoded with one ``self.tokenizer.batch_decode`` call, which requires
        a tokenizer.
        """
        if decode_to_text:
            assert self.tokenizer is not None
            decoded = self.tokenizer.batch_decode(token_logprobs_idx)
            return list(zip(token_logprobs_val, token_logprobs_idx, decoded))
        return [
            (lp, tid, None) for lp, tid in zip(token_logprobs_val, token_logprobs_idx)
        ]

    def detokenize_top_logprobs_tokens(
        self,
        token_logprobs_val: List[float],
        token_logprobs_idx: List[int],
        decode_to_text: bool,
    ):
        """Detokenize a per-position list of top-k (logprobs, ids) entries.

        Each position with a non-empty value list becomes the output of
        ``self.detokenize_logprob_tokens`` for that position. Positions whose
        value list is empty or None become None.
        """
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
        return [
            self.detokenize_logprob_tokens(
                position_vals, token_logprobs_idx[pos], decode_to_text
            )
            if position_vals
            else None
            for pos, position_vals in enumerate(token_logprobs_val)
        ]

    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
        """Report one request's progress (row ``i`` of ``recv_obj``) to the metrics collector.

        * If ``state.first_token_time == 0.0`` and the server is not in PREFILL
          disaggregation mode: stamp the first-token time (also used as the
          inter-token clock start) and observe time-to-first-token.
        * Otherwise: if the completion-token count changed since the last call,
          observe the inter-token latency for the new tokens and advance the
          clock.
        * If ``state.finished``: observe the finished-request summary (prompt,
          completion and cached tokens, end-to-end latency, grammar flag).
        """
        # A missing or empty ``completion_tokens`` field counts as zero tokens.
        produced = (
            recv_obj.completion_tokens[i]
            if getattr(recv_obj, "completion_tokens", None)
            else 0
        )

        first_observation = (
            state.first_token_time == 0.0
            and self.disaggregation_mode != DisaggregationMode.PREFILL
        )
        if first_observation:
            stamp = time.time()
            state.first_token_time = state.last_time = stamp
            state.last_completion_tokens = produced
            self.metrics_collector.observe_time_to_first_token(
                state.first_token_time - state.created_time
            )
        else:
            delta = produced - state.last_completion_tokens
            if delta:
                stamp = time.time()
                self.metrics_collector.observe_inter_token_latency(
                    stamp - state.last_time,
                    delta,
                )
                state.last_time = stamp
                state.last_completion_tokens = produced

        if state.finished:
            params = state.obj.sampling_params
            # Truthy (the first constraint value that is set) when any
            # grammar-style constraint was requested.
            has_grammar = (
                params.get("json_schema", None)
                or params.get("regex", None)
                or params.get("ebnf", None)
                or params.get("structural_tag", None)
            )
            self.metrics_collector.observe_one_finished_request(
                recv_obj.prompt_tokens[i],
                produced,
                recv_obj.cached_tokens[i],
                state.finished_time - state.created_time,
                has_grammar,
            )

    def dump_requests(self, state: ReqState, out_dict: dict):
        """Buffer a request record and flush the buffer to disk once it is full.

        Each record is ``(state.obj, out_dict, state.created_time, now)``. When
        the buffer holds at least ``self.dump_requests_threshold`` records, it
        is handed to ``self._dump_data_to_file`` as
        ``<dump_requests_folder>/<YYYY-mm-dd_HH-MM-SS>.pkl``. The buffer is then
        replaced by a fresh empty list.
        """
        record = (state.obj, out_dict, state.created_time, time.time())
        self.dump_request_list.append(record)

        if len(self.dump_request_list) >= self.dump_requests_threshold:
            stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            target = os.path.join(self.dump_requests_folder, stamp + ".pkl")
            self._dump_data_to_file(
                data_list=self.dump_request_list,
                filename=target,
                log_message=f"Dump {len(self.dump_request_list)} requests to {target}",
            )
            self.dump_request_list = []

2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
        """Append a request record to the crash-dump window and evict stale records.

        Records are ``(state.obj, out_dict, state.created_time, now)``. Records
        whose timestamp (tuple slot 3) is 300 seconds or more older than now
        are dropped from the front of ``self.crash_dump_request_list``. This
        relies on the list being a deque ordered by that timestamp.
        """
        now = time.time()
        window = self.crash_dump_request_list
        window.append((state.obj, out_dict, state.created_time, now))
        # Remove requests older than 5 minutes based on finish time
        while window and now - window[0][3] >= 300:
            window.popleft()

2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
    def _dump_data_to_file(
        self, data_list: List[Tuple], filename: str, log_message: str
    ):
        """Pickle ``data_list`` together with ``self.server_args`` to ``filename``.

        The payload is ``{"server_args": ..., "requests": <copy of data_list>}``.
        The list is shallow-copied immediately. Directory creation and the
        pickle write run in a worker thread via ``asyncio.to_thread``, so the
        event loop is not blocked. Must be called while an asyncio event loop
        is running (it uses ``asyncio.create_task``). A failed write is logged.
        """
        logger.info(log_message)
        to_dump_with_server_args = {
            "server_args": self.server_args,
            # Snapshot now so later changes to the caller's list are not written.
            "requests": data_list.copy(),
        }

        def background_task():
            directory = os.path.dirname(filename)
            # os.makedirs("") raises FileNotFoundError, so a bare filename with
            # no directory part must skip directory creation.
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump_with_server_args, f)

        task = asyncio.create_task(asyncio.to_thread(background_task))

        # The event loop keeps only weak references to tasks. Hold a strong
        # reference until the write finishes so the task cannot be
        # garbage-collected mid-flight.
        pending = getattr(self, "_pending_dump_tasks", None)
        if pending is None:
            pending = set()
            self._pending_dump_tasks = pending
        pending.add(task)

        def on_done(finished_task):
            pending.discard(finished_task)
            if not finished_task.cancelled() and finished_task.exception() is not None:
                logger.error(
                    f"Failed to dump data to {filename}: {finished_task.exception()!r}"
                )

        task.add_done_callback(on_done)

Lianmin Zheng's avatar
Lianmin Zheng committed
2251
    def _handle_abort_req(self, recv_obj):
        """Mark an aborted request finished and push its terminal output.

        Health-check requests are ignored. Otherwise the request's state is
        marked finished, and one final output dict is appended to
        ``state.out_list`` before ``state.event`` is set. If the abort carries a
        ``finished_reason``, that reason is reported as-is. Otherwise an
        "Abort before prefill" reason is reported with empty text and zero
        token counts.
        """
        if is_health_check_generate_req(recv_obj):
            return
        state = self.rid_to_state[recv_obj.rid]
        # With several tokenizer workers the internal rid carries a worker
        # prefix. Report the original rid in both output shapes so callers
        # always see the id they submitted.
        rid = recv_obj.rid
        if self.server_args.tokenizer_worker_num > 1:
            rid = get_origin_rid(rid)
        state.finished = True
        if recv_obj.finished_reason:
            out = {
                "meta_info": {
                    "id": rid,
                    "finish_reason": recv_obj.finished_reason,
                },
            }
        else:
            out = {
                "text": "",
                "meta_info": {
                    "id": rid,
                    "finish_reason": {
                        "type": "abort",
                        "message": "Abort before prefill",
                    },
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                },
            }
        state.out_list.append(out)
        state.event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
2281

2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
    def _handle_open_session_req_output(self, recv_obj):
        self.session_futures[recv_obj.session_id].set_result(
            recv_obj.session_id if recv_obj.success else None
        )

    def _handle_update_weights_from_disk_req_output(self, recv_obj):
        if self.server_args.dp_size == 1:
            self.model_update_result.set_result(recv_obj)
        else:  # self.server_args.dp_size > 1
            self.model_update_tmp.append(recv_obj)
2292
            # set future if the all results are received
2293
2294
2295
            if len(self.model_update_tmp) == self.server_args.dp_size:
                self.model_update_result.set_result(self.model_update_tmp)

2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
        items: Optional[Union[str, List[str], List[List[int]]]] = None,
        label_token_ids: Optional[List[int]] = None,
        apply_softmax: bool = False,
        item_first: bool = False,
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
        See Engine.score() for more details.

        Each item is concatenated with ``query`` (item first when
        ``item_first``) as text or as token ids. One token is generated per
        prompt, and the logprobs of ``label_token_ids`` at that position are
        collected. Each returned row is aligned with ``label_token_ids``:

        * with ``apply_softmax``, a softmax over the label logprobs;
        * otherwise ``exp(logprob)``, with 0.0 for labels that got no logprob.

        Raises ValueError when ``label_token_ids`` is missing, when a label id
        is not below the tokenizer's ``vocab_size``, or when the query/items
        types are not both text or both token-id lists.
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")

        tokenizer = self.tokenizer
        if tokenizer is not None:
            limit = tokenizer.vocab_size
            for label in label_token_ids:
                if label >= limit:
                    raise ValueError(
                        f"Token ID {label} is out of vocabulary (vocab size: {limit})"
                    )

        # Build either text prompts or token-id prompts.
        if isinstance(query, str) and (
            isinstance(items, str)
            or (isinstance(items, list) and (not items or isinstance(items[0], str)))
        ):
            texts = [items] if isinstance(items, str) else items
            payload = {
                "text": [
                    f"{piece}{query}" if item_first else f"{query}{piece}"
                    for piece in texts
                ]
            }
        elif (
            isinstance(query, list)
            and isinstance(items, list)
            and items
            and isinstance(items[0], list)
        ):
            payload = {
                "input_ids": [
                    piece + query if item_first else query + piece for piece in items
                ]
            }
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        batch_request = GenerateReqInput(
            **payload,
            return_logprob=True,
            token_ids_logprob=label_token_ids,
            stream=False,
            sampling_params={"max_new_tokens": 1},
        )

        results = await self.generate_request(batch_request, request).__anext__()

        scores = []
        for result in results:
            # Logprobs of the label tokens at the single generated position.
            first_position = result["meta_info"].get("output_token_ids_logprobs", [])[0]
            found = {
                token_id: logprob
                for logprob, token_id, _ in first_position
                if token_id in label_token_ids
            }
            label_logprobs = [found.get(t, float("-inf")) for t in label_token_ids]

            if apply_softmax:
                row = torch.softmax(torch.tensor(label_logprobs), dim=0).tolist()
            else:
                row = [
                    0.0 if lp == float("-inf") else math.exp(lp) for lp in label_logprobs
                ]
            scores.append(row)

        return scores

2390

2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
class ServerStatus(Enum):
    """Coarse server state; only ``Up`` is considered healthy."""

    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        """Return True only for ``ServerStatus.Up``."""
        return self is ServerStatus.Up


2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

    if is_cross_node:
        # Fallback to default CPU transport for multi-node
        return "default"
    else:
        return "cuda_ipc"


2411
2412
2413
2414
2415
2416
2417
2418
2419
async def print_exception_wrapper(func):
    """Await ``func()`` and turn any uncaught exception into a logged crash.

    Asyncio tasks can fail without printing their exception, so this wrapper
    handles it explicitly. On any ``Exception`` it:

    1. logs the traceback;
    2. if ``func`` is a bound ``TokenizerManager`` method, calls
       ``dump_requests_before_crash()`` on a best-effort basis (a failure
       there is logged, not propagated);
    3. kills this process tree, including this process, and exits with
       status 1.
    """
    try:
        await func()
    except Exception:
        exc_traceback = get_exception_traceback()
        logger.error(f"TokenizerManager hit an exception: {exc_traceback}")
        owner = getattr(func, "__self__", None)
        if isinstance(owner, TokenizerManager):
            # If the dump itself raised, the exception would escape before
            # kill_process_tree runs, leaving the crashed process alive.
            try:
                owner.dump_requests_before_crash()
            except Exception:
                logger.exception("Failed to dump requests before crash")
        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)


2427
class SignalHandler:
    """Signal callbacks that act on a ``TokenizerManager``."""

    def __init__(self, tokenizer_manager: TokenizerManager):
        self.tokenizer_manager = tokenizer_manager

    def sigterm_handler(self, signum=None, frame=None):
        """On SIGTERM: log, then set ``gracefully_exit`` on the manager."""
        logger.warning(
            f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
        )
        manager = self.tokenizer_manager
        manager.gracefully_exit = True

    def running_phase_sigquit_handler(self, signum=None, frame=None):
        """On SIGQUIT from a child: log, dump in-flight requests, kill the process tree."""
        logger.error(
            "Received sigquit from a child process. It usually means the child failed."
        )
        manager = self.tokenizer_manager
        manager.dump_requests_before_crash()
        kill_process_tree(os.getpid())

2444
2445
2446
2447
2448

# Type of the reply objects collected by _Communicator (see handle_recv).
T = TypeVar("T")


class _Communicator(Generic[T]):
    """Send a request through ``sender`` and gather ``fan_out`` replies for it.

    At most one request is in flight at any time. Concurrent callers wait
    in FIFO order until the current request has received all of its replies.
    """

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        # Number of handle_recv() calls that complete one request.
        self._fan_out = fan_out
        # Non-None while a request is in flight; set once all replies arrived.
        self._result_event: Optional[asyncio.Event] = None
        # Replies collected so far for the in-flight request.
        self._result_values: Optional[List[T]] = None
        # Wake-up events of callers waiting for their turn, oldest first.
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        """Send ``obj`` (when truthy) and return the list of replies collected for it.

        With more than one tokenizer worker, ``obj.rids`` is prefixed with this
        process's pid before sending. A uuid-based rid is generated when
        ``obj.rids`` is None. The prefix presumably routes replies back to this
        worker; ``handle_recv`` strips it again.
        """
        global _global_tokenizer_worker_num

        ready_event = asyncio.Event()
        if self._result_event is not None or len(self._ready_queue) > 0:
            self._ready_queue.append(ready_event)
            await ready_event.wait()
            # Only the head of the queue is ever signalled, so we are the head.
            # We dequeue ourselves here instead of letting the previous holder
            # pop us before we run. A caller arriving between that hand-off and
            # this wake-up therefore still sees a non-empty queue and waits.
            # Otherwise it would start its own request and trip the asserts
            # below.
            self._ready_queue.popleft()
            assert self._result_event is None
            assert self._result_values is None

        if obj:
            if _global_tokenizer_worker_num > 1:
                if obj.rids is None:
                    obj.rids = f"{os.getpid()}_{uuid.uuid4().hex}_Communicator"
                else:
                    if isinstance(obj.rids, str):
                        obj.rids = f"{os.getpid()}_{obj.rids}"
                    elif isinstance(obj.rids, list):
                        obj.rids = [f"{os.getpid()}_{rid}" for rid in obj.rids]
            self._sender.send_pyobj(obj)

        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()
        result_values = self._result_values
        self._result_event = self._result_values = None

        # Hand the turn to the oldest waiter; it removes itself when it runs.
        if len(self._ready_queue) > 0:
            self._ready_queue[0].set()

        return result_values

    def handle_recv(self, recv_obj: T):
        """Record one reply; wake the in-flight caller after ``fan_out`` replies.

        With more than one tokenizer worker, the pid prefix added by
        ``__call__`` is stripped from ``recv_obj.rids`` via ``get_origin_rid``.
        This applies to both string and list forms.
        """
        global _global_tokenizer_worker_num
        if _global_tokenizer_worker_num > 1:
            # If rids is a string and not empty, remove the prefix
            if (
                hasattr(recv_obj, "rids")
                and isinstance(recv_obj.rids, str)
                and recv_obj.rids
            ):
                recv_obj.rids = get_origin_rid(recv_obj.rids)
            # If rids is a list, remove prefix from each element
            elif hasattr(recv_obj, "rids") and isinstance(recv_obj.rids, list):
                recv_obj.rids = [get_origin_rid(rid) for rid in recv_obj.rids]

        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()
Lianmin Zheng's avatar
Lianmin Zheng committed
2506
2507
2508
2509
2510
2511
2512


# Note: request abort handling logic
# We should handle all of the following cases correctly.
#
# | entrypoint | is_streaming | status          | abort engine    | cancel asyncio task   | rid_to_state                |
# | ---------- | ------------ | --------------- | --------------- | --------------------- | --------------------------- |
# | http       | yes          | validation      | background task | FastAPI               | del in _handle_abort_req    |
# | http       | yes          | waiting queue   | background task | FastAPI               | del in _handle_abort_req    |
# | http       | yes          | running         | background task | FastAPI               | del in _handle_batch_output |
# | http       | no           | validation      | http exception  | http exception        | del in _handle_abort_req    |
# | http       | no           | waiting queue   | type 1          | type 1 exception      | del in _handle_abort_req    |
# | http       | no           | running         | type 3          | type 3 exception      | del in _handle_batch_output |
#