scheduler.py 115 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from typing import Deque, Dict, List, Optional, Tuple, Union
28

29
import psutil
30
import setproctitle
31
import torch
32
import torch.distributed
33
import zmq
34
35
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext
36
from torch.distributed import barrier
37

Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
40
41
42
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
43
44
45
46
47
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
48
49
50
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
    DecodeKVCacheOffloadManager,
)
Byron Hsu's avatar
Byron Hsu committed
51
52
53
54
55
56
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
57
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
58
    ReqToMetadataIdxAllocator,
59
    TransferBackend,
60
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
61
)
62
from sglang.srt.distributed import get_pp_group, get_world_group
63
from sglang.srt.environ import envs
fzyzcjy's avatar
fzyzcjy committed
64
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.moe import initialize_moe_config
67
68
from sglang.srt.managers.io_struct import (
    AbortReq,
69
70
    BaseBatchReq,
    BaseReq,
71
72
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
73
74
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
75
    CloseSessionReqInput,
76
    DestroyWeightsUpdateGroupReqInput,
77
    ExpertDistributionReq,
78
    ExpertDistributionReqOutput,
79
    ExpertDistributionReqType,
80
81
    FlushCacheReqInput,
    FlushCacheReqOutput,
82
    FreezeGCReq,
83
84
    GetInternalStateReq,
    GetInternalStateReqOutput,
85
86
    GetLoadReqInput,
    GetLoadReqOutput,
87
    GetWeightsByNameReqInput,
88
    HealthCheckOutput,
89
90
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
91
    InitWeightsUpdateGroupReqInput,
92
93
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
94
95
    OpenSessionReqInput,
    OpenSessionReqOutput,
96
    ProfileReq,
97
98
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
99
100
    RpcReqInput,
    RpcReqOutput,
101
102
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
103
104
    SetInternalStateReq,
    SetInternalStateReqOutput,
105
106
    SlowDownReqInput,
    SlowDownReqOutput,
107
108
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
109
110
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
111
    UpdateWeightFromDiskReqInput,
112
    UpdateWeightsFromDistributedReqInput,
113
    UpdateWeightsFromIPCReqInput,
114
    UpdateWeightsFromTensorReqInput,
115
)
116
from sglang.srt.managers.mm_utils import init_embedding_cache
117
from sglang.srt.managers.overlap_utils import FutureMap
118
119
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
120
    ModelWorkerBatch,
Mick's avatar
Mick committed
121
    MultimodalInputs,
122
    Req,
123
    RequestStage,
124
125
    ScheduleBatch,
)
126
127
128
129
130
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
131
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
132
133
134
135
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
136
137
138
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
139
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
140
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
141
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
142
143
144
from sglang.srt.managers.scheduler_runtime_checker_mixin import (
    SchedulerRuntimeCheckerMixin,
)
145
146
147
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
148
from sglang.srt.managers.session_controller import Session
149
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
tarinkk's avatar
tarinkk committed
150
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
151
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
152
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
153
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
154
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
155
from sglang.srt.model_executor.forward_batch_info import ForwardMode
156
from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin
157
from sglang.srt.parser.reasoning_parser import ReasoningParser
158
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
159
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
160
161
from sglang.srt.tracing.trace import (
    process_tracing_init,
162
    trace_event_batch,
163
164
    trace_set_proc_propagate_context,
    trace_set_thread_info,
165
    trace_slice_batch,
166
167
168
    trace_slice_end,
    trace_slice_start,
)
169
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
170
from sglang.srt.utils import (
171
    DynamicGradMode,
172
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
173
    configure_gc_logger,
174
    configure_logger,
175
    freeze_gc,
176
    get_available_gpu_memory,
177
    get_bool_env_var,
178
    get_int_env_var,
179
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
180
    kill_itself_when_parent_died,
181
    numa_bind_to_node,
182
    point_to_point_pyobj,
183
184
    require_mlp_sync,
    require_mlp_tp_gather,
185
    set_gpu_proc_affinity,
186
187
188
    set_random_seed,
    suppress_other_loggers,
)
189
190
191
192
193
from sglang.srt.utils.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
194
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
195
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
196
197
198

logger = logging.getLogger(__name__)

# Debug-only knobs that force the retract-decode path to be exercised.
TEST_RETRACT = envs.SGLANG_TEST_RETRACT.get()
TEST_RETRACT_INTERVAL = envs.SGLANG_TEST_RETRACT_INTERVAL.get()
TEST_RETRACT_NO_PREFILL_BS = envs.SGLANG_TEST_RETRACT_NO_PREFILL_BS.get()

# Grammar timeout, read from SGLANG_GRAMMAR_TIMEOUT and falling back to 300
# (unit presumably seconds -- confirm at the use site).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))
204

205

206
207
208
209
210
@dataclass
class EmbeddingBatchResult:
    """Container for the output of an embedding batch.

    Attributes:
        embeddings: the embeddings tensor produced for the batch.
    """

    embeddings: torch.Tensor


Byron Hsu's avatar
Byron Hsu committed
211
212
class Scheduler(
    SchedulerOutputProcessorMixin,
213
214
215
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
216
217
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
218
    SchedulerMultiplexMixin,
219
    SchedulerRuntimeCheckerMixin,
220
    SchedulerPPMixin,
Byron Hsu's avatar
Byron Hsu committed
221
):
222
223
224
225
226
227
228
229
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: parse args, launch workers, and init state.

        Args:
            server_args: global server configuration.
            port_args: IPC names / ports used for inter-process communication.
            gpu_id: device index this scheduler's worker runs on.
            tp_rank: tensor-parallel rank.
            moe_ep_rank: MoE expert-parallel rank.
            pp_rank: pipeline-parallel rank.
            dp_rank: data-parallel rank, or ``None``.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_priority_scheduling = server_args.enable_priority_scheduling
        self.abort_on_priority_when_disabled = (
            server_args.abort_on_priority_when_disabled
        )
        self.schedule_low_priority_values_first = (
            server_args.schedule_low_priority_values_first
        )
        self.priority_scheduling_preemption_threshold = (
            server_args.priority_scheduling_preemption_threshold
        )
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.enable_pdmux = server_args.enable_pdmux
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        # KV cache events are only emitted from TP rank 0.
        self.enable_kv_cache_events = bool(
            server_args.kv_events_config and tp_rank == 0
        )
        self.enable_trace = server_args.enable_trace
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.page_size = server_args.page_size
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None

        # Distributed rank info
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication
        self.init_sockets(server_args, port_args)

        # Init pdmux context
        if self.enable_pdmux:
            self.init_pdmux()

        # Init tokenizer
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        from sglang.srt.managers.tp_worker import TpModelWorker

        self.tp_worker = TpModelWorker(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        draft_worker_kwargs = dict(
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            server_args=server_args,
            nccl_port=port_args.nccl_port,
            target_worker=self.tp_worker,
            dp_rank=dp_rank,
        )

        # NOTE: this mutates the shared server_args so the draft worker loads
        # with its own load_format (the target worker is already built above).
        if server_args.speculative_draft_load_format is not None:
            server_args.load_format = server_args.speculative_draft_load_format
            logger.info(
                f"Using draft model load_format: '{server_args.speculative_draft_load_format}'"
            )

        # Draft workers are looked up via `SpeculativeAlgorithm` registry; new
        # algorithms should register their factory instead of patching this code.
        if self.spec_algorithm.name in {"EAGLE", "EAGLE3"}:
            draft_worker_kwargs["enable_overlap"] = self.enable_overlap
        self.draft_worker = self.spec_algorithm.create_draft_worker(
            **draft_worker_kwargs
        )

        # Dispatch the model worker
        if self.spec_algorithm.is_none():
            self.model_worker = self.tp_worker
        else:
            self.model_worker = self.draft_worker

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if get_global_server_args().pp_max_micro_batch_size is None:
            get_global_server_args().pp_max_micro_batch_size = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        # With DP attention enabled, the entry rank is attn_tp_rank==0;
        # otherwise the entry rank is TP group local rank 0.
        # For #11910, use the CPU communication group to broadcast VLM Python objects,
        # avoiding any coupling with CUDA streams/devices.
        if self.server_args.enable_dp_attention:
            self.cpu_group = self.attn_tp_cpu_group
            self.is_entry_rank = self.attn_tp_rank == 0
        else:
            self.cpu_group = self.tp_cpu_group
            self.is_entry_rank = self.tp_group.rank == 0

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None

        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The current split prefill batch
        self.split_prefill_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.sessions: Dict[str, Session] = {}
        self.default_stream: CudaStream = torch.get_device_module(
            self.device
        ).current_stream()
        if self.device == "cpu":
            self.default_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill. A non-positive size (e.g. -1) disables it; a
        # missing (None) size is treated as disabled too instead of raising
        # TypeError on the comparison.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
            self.enable_priority_scheduling,
            self.schedule_low_priority_values_first,
        )
        # Enable preemption for priority scheduling.
        self.try_preemption = self.enable_priority_scheduling
        self.init_new_token_ratio = min(
            envs.SGLANG_INIT_NEW_TOKEN_RATIO.get()
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio * envs.SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR.get(),
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / envs.SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS.get()
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. Resolve the parent process first so that the
        # attribute already exists for the whole lifetime of the watchdog
        # thread (NOTE(review): confirm watchdog_thread's attribute usage).
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profiler()
        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init disaggregation
        self.init_disaggregation()

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)

        if self.enable_kv_cache_events:
            self.init_kv_events(server_args.kv_events_config)

        if envs.SGLANG_LOG_GC.get():
            configure_gc_logger()

        # Init prefill kv split size when deterministic inference is enabled with various attention backends
        self.init_deterministic_inference_config()

        # Init overlap
        self.init_overlap()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
                ),
                (
                    SendWeightsToRemoteInstanceReqInput,
                    self.send_weights_to_remote_instance,
                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (FreezeGCReq, self.handle_freeze_gc),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (GetLoadReqInput, self.get_load),
            ]
        )

571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
    def init_sockets(self, server_args: ServerArgs, port_args: PortArgs):
        """Create the ZMQ sockets this scheduler uses to talk to other processes.

        Only the rank with ``pp_rank == 0`` and ``attn_tp_rank == 0`` owns real
        sockets: it PULLs requests from the tokenizer side, holds a DEALER
        socket for RPC requests, and PUSHes outputs to the tokenizer and to the
        detokenizer (or directly to the TokenizerManager when
        ``skip_tokenizer_init`` is set). Every other rank gets ``None``
        receivers and senders whose ``send_output`` does nothing.

        Also sets ``self.idle_sleeper`` (only built on the socket-owning rank
        when ``sleep_on_idle`` is enabled) and, when metrics are enabled for
        this scheduler, ``self.send_metrics_from_scheduler``.
        """
        zmq_ctx = zmq.Context(2)
        self.idle_sleeper = None

        class SenderWrapper:
            # Lets callers always call ``send_output``; wrapping a ``None``
            # socket turns every send into a no-op.
            def __init__(self, socket: zmq.Socket):
                self.socket = socket

            def send_output(
                self,
                output: Union[BaseReq, BaseBatchReq],
                recv_obj: Optional[Union[BaseReq, BaseBatchReq]] = None,
            ):
                if self.socket is None:
                    return
                # Multi-HTTP-worker case: copy the originating worker's IPC
                # address onto communicator replies that lack one.
                should_route_back = (
                    isinstance(recv_obj, BaseReq)
                    and recv_obj.http_worker_ipc is not None
                    and output.http_worker_ipc is None
                )
                if should_route_back:
                    output.http_worker_ipc = recv_obj.http_worker_ipc
                self.socket.send_pyobj(output)

        owns_io_sockets = self.pp_rank == 0 and self.attn_tp_rank == 0
        if not owns_io_sockets:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SenderWrapper(None)
            self.send_to_detokenizer = SenderWrapper(None)
        else:
            self.recv_from_tokenizer = get_zmq_socket(
                zmq_ctx, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                zmq_ctx, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            tokenizer_sock = get_zmq_socket(
                zmq_ctx, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            # skip_tokenizer_init: send straight to the TokenizerManager;
            # otherwise send to the DetokenizerManager.
            output_ipc_name = (
                port_args.tokenizer_ipc_name
                if server_args.skip_tokenizer_init
                else port_args.detokenizer_ipc_name
            )
            detokenizer_sock = get_zmq_socket(
                zmq_ctx, zmq.PUSH, output_ipc_name, False
            )

            self.send_to_tokenizer = SenderWrapper(tokenizer_sock)
            self.send_to_detokenizer = SenderWrapper(detokenizer_sock)

            if self.server_args.sleep_on_idle:
                watched_sockets = [self.recv_from_tokenizer, self.recv_from_rpc]
                self.idle_sleeper = IdleSleeper(watched_sockets)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                zmq_ctx, zmq.PUSH, port_args.metrics_ipc_name, False
            )

640
641
642
643
644
645
    def init_deterministic_inference_config(self):
        """Set ``self.truncation_align_size`` for deterministic inference.

        The value stays ``None`` unless deterministic inference is enabled AND
        the attention backend is ``flashinfer`` or ``triton``. For those
        backends it is read from a backend-specific integer env var, defaulting
        to 4096.
        """
        align_size = None
        if self.server_args.enable_deterministic_inference:
            backend = self.server_args.attention_backend
            if backend == "flashinfer":
                align_size = get_int_env_var(
                    "SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096
                )
            elif backend == "triton":
                align_size = get_int_env_var(
                    "SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096
                )
        self.truncation_align_size = align_size

657
658
659
    def init_tokenizer(self):
        """Load the tokenizer/processor and tokenizer-dependent settings.

        Sets ``self.is_generation`` from the model config and always sets both
        ``self.tokenizer`` and ``self.processor``:

        - both are ``None`` with ``skip_tokenizer_init``;
        - for multimodal models a processor is loaded and the tokenizer is
          taken from it;
        - otherwise only a tokenizer is loaded and ``self.processor`` is
          ``None``.

        When a reasoning parser is configured and a tokenizer exists, sets
        ``self.tokenizer.think_end_id`` to the first token id of the parser's
        think-end marker.

        Raises:
            ValueError: if the think-end marker encodes to zero tokens.
        """
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only models have no processor; set it explicitly so the
            # attribute always exists (previously it was left undefined).
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

        # Set think_end_id if --reasoning_parser is enabled
        if server_args.reasoning_parser and self.tokenizer is not None:
            reasoning_parser = ReasoningParser(
                model_type=server_args.reasoning_parser, stream_reasoning=False
            )
            think_end_token = reasoning_parser.detector.think_end_token
            think_end_ids = self.tokenizer.encode(
                think_end_token, add_special_tokens=False
            )
            if not think_end_ids:
                raise ValueError(
                    f"Reasoning parser think-end token {think_end_token!r} "
                    "encodes to no tokens with this tokenizer."
                )
            if len(think_end_ids) > 1:
                # Only a single id is stored; make the truncation visible.
                logger.warning(
                    "Think-end token %r encodes to %d tokens; using only the "
                    "first id %d as think_end_id.",
                    think_end_token,
                    len(think_end_ids),
                    think_end_ids[0],
                )
            self.tokenizer.think_end_id = think_end_ids[0]

690
691
692
693
694
695
696
697
698
699
700
    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Sets ``self.req_to_token_pool`` and ``self.token_to_kv_pool_allocator``
        from ``self.tp_worker.get_memory_pool()`` and then chooses
        ``self.tree_cache``:

        * chunked prefill enabled and radix cache disabled -> ``SWAChunkCache``
          for hybrid models, otherwise ``ChunkCache``;
        * otherwise, first match wins: the experimental C++ radix tree
          (``SGLANG_EXPERIMENTAL_CPP_RADIX_TREE=1``), ``HiRadixCache``
          (hierarchical cache), ``SWARadixCache`` (``self.is_hybrid``),
          ``MambaRadixCache`` (``self.is_hybrid_gdn``), ``LMCRadixCache``
          (``--enable-lmcache``), or the plain ``RadixCache``.

        It also sets ``self.decode_offload_manager`` (``None`` unless decode-side
        KV-cache offloading is enabled in decode disaggregation mode) and
        ``self.decode_mem_cache_buf_multiplier``, and initializes the multimodal
        embedding cache.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            # No prefix sharing: a simple chunk cache is enough.
            chunk_cache_cls = SWAChunkCache if self.is_hybrid else ChunkCache
            self.tree_cache = chunk_cache_cls(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        else:
            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
                # lazy import to avoid JIT overhead
                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

                # NOTE(review): disable=False ignores --disable-radix-cache (this
                # branch is reachable with it set when chunked_prefill_size is
                # None), and tp_cache_group is always tp_cpu_group, unlike
                # HiRadixCache below which uses attn_tp_cpu_group under DP
                # attention. Confirm both are intended.
                self.tree_cache = RadixCacheCpp(
                    disable=False,
                    use_hicache=self.enable_hierarchical_cache,
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool_allocator,
                    tp_cache_group=self.tp_cpu_group,
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )
            elif self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    tp_cache_group=(
                        self.attn_tp_cpu_group
                        if self.server_args.enable_dp_attention
                        else self.tp_cpu_group
                    ),
                    page_size=self.page_size,
                    eviction_policy=server_args.radix_eviction_policy,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    hicache_io_backend=server_args.hicache_io_backend,
                    hicache_mem_layout=server_args.hicache_mem_layout,
                    enable_metrics=self.enable_metrics,
                    hicache_storage_backend=server_args.hicache_storage_backend,
                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                    model_name=server_args.served_model_name,
                    storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
                # Let the TP worker observe per-layer transfer progress of the
                # hierarchical cache controller.
                self.tp_worker.register_hicache_layer_transfer_counter(
                    self.tree_cache.cache_controller.layer_done_counter
                )
            elif self.is_hybrid:
                self.tree_cache = SWARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    sliding_window_size=self.sliding_window_size,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
            elif self.is_hybrid_gdn:
                self.tree_cache = MambaRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                )
            elif server_args.enable_lmcache:
                # Imported lazily: only needed when LMCache is enabled.
                from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                    LMCRadixCache,
                )

                self.tree_cache = LMCRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    model_config=self.model_config,
                    tp_size=self.tp_size,
                    rank=self.tp_rank,
                    tp_group=self.tp_group,
                    eviction_policy=server_args.radix_eviction_policy,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                    eviction_policy=server_args.radix_eviction_policy,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )

        # Decode-side KV-cache offloading only applies in decode
        # disaggregation mode with the offload flag set.
        if (
            server_args.disaggregation_mode == "decode"
            and server_args.disaggregation_decode_enable_offload_kvcache
        ):
            self.decode_offload_manager = DecodeKVCacheOffloadManager(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                tree_cache=self.tree_cache,
                server_args=self.server_args,
            )
        else:
            self.decode_offload_manager = None

        # 1 without speculative decoding; otherwise
        # num_draft_tokens + topk * num_steps (a missing topk/steps counts as 1).
        # Presumably scales per-request decode memory reservations; confirm at
        # the use sites.
        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    (server_args.speculative_eagle_topk or 1)
                    * (server_args.speculative_num_steps or 1)
                )
            )
        )

        # SGLANG_VLM_CACHE_SIZE_MB is given in MiB; the cache takes bytes.
        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

Byron Hsu's avatar
Byron Hsu committed
829
    def init_disaggregation(self):
        """Set up prefill/decode disaggregation state for this scheduler.

        Always records ``self.disaggregation_mode`` and ``self.transfer_backend``
        from the server args. Then, depending on the mode:

        * DECODE: metadata buffers sized ``2 * req_to_token_pool.size``, the
          transfer queue (decode requests polling for KV cache) and the
          pre-allocation queue.
        * PREFILL: metadata buffers sized ``2 * max_running_requests``, the
          bootstrap queue and the list of requests whose KV cache is still
          being sent.
        * any other mode: nothing else is created.
        """
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            self._init_disagg_metadata_buffers(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=self._get_disagg_draft_token_to_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            self._init_disagg_metadata_buffers(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=self._get_disagg_draft_token_to_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    def _init_disagg_metadata_buffers(self, buffer_size: int):
        """Create the metadata index allocator and metadata buffers (shared by both modes)."""
        self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
            buffer_size
        )
        self.disagg_metadata_buffers = MetadataBuffers(
            buffer_size,
            hidden_size=self.model_config.hf_text_config.hidden_size,
            hidden_states_dtype=self.model_config.dtype,
            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
        )

    def _get_disagg_draft_token_to_kv_pool(self):
        """Return the draft model's KV pool, or None with no draft worker or n-gram speculation."""
        if self.draft_worker is None or self.spec_algorithm.is_ngram():
            return None
        return self.draft_worker.model_runner.token_to_kv_pool
Byron Hsu's avatar
Byron Hsu committed
924

925
926
927
928
929
930
931
932
933
934
935
936
937
    def init_overlap(self):
        """Create the device streams and bookkeeping used by the overlap scheduler.

        This is a no-op unless ``self.enable_overlap`` is set. Otherwise it
        creates a forward stream and a copy stream on ``self.device``, each with
        a matching stream context. It also creates the ``FutureMap`` and the
        two-slot ``batch_record_buf`` that ``record_batch_in_overlap`` writes into.
        """
        if not self.enable_overlap:
            return

        # Same module object for every call; look it up once.
        device_module = torch.get_device_module(self.device)

        self.forward_stream: CudaStream = device_module.Stream()
        self.forward_stream_ctx: CudaStreamContext = device_module.stream(
            self.forward_stream
        )
        self.copy_stream: CudaStream = device_module.Stream()
        self.copy_stream_ctx: CudaStreamContext = device_module.stream(
            self.copy_stream
        )

        self.future_map = FutureMap(
            self.max_running_requests, self.device, self.spec_algorithm
        )

        # Two slots + a cursor: keeps references to the last batches alive.
        self.batch_record_buf = [None] * 2
        self.batch_record_ct = 0

    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
        """Hold a reference to ``model_worker_batch`` in a two-slot ring buffer.

        The cursor ``batch_record_ct`` is advanced modulo 2 first. The batch is
        then stored in that slot, overwriting whatever was recorded two calls ago.
        """
        # FIXME(lsyin): hacky way to keep a reference so torch GC does not free
        # GPU tensors that are still in use.
        # NOTE: More reliable: record all tensors into the forward stream.
        # NOTE: - future tensors must always be read from the future map;
        #       - non-future tensors (produced only by the schedule stream) must
        #         keep a live reference for the whole forward pass.
        next_slot = (self.batch_record_ct + 1) % 2
        self.batch_record_ct = next_slot
        self.batch_record_buf[next_slot] = model_worker_batch

953
954
955
956
    def init_moe_config(self):
        """Initialize the MoE config when the HF config declares ``num_experts_per_tok``.

        Models whose HF config has no such attribute are left untouched.
        """
        hf_config = self.model_config.hf_config
        if not hasattr(hf_config, "num_experts_per_tok"):
            return
        initialize_moe_config(self.server_args)

957
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration receives and processes incoming requests, then picks the
        next batch. If there is a batch, the loop runs it and processes its
        result synchronously. If there is none, it runs the idle self-check.
        The loop never returns.
        """
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch
975

976
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the current batch first and queues a copy of it
        together with its result. Only then does it process the result of the
        previous iteration's batch, so CPU-side result handling for batch N-1
        happens after batch N has been launched. The loop never returns.
        """
        # (batch copy, result) pairs waiting to be processed one iteration later.
        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            batch_result = None
            if batch:
                batch_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), batch_result))

            if self.last_batch:
                # Process the results of the last batch (queued one iteration ago)
                tmp_batch, tmp_result = self.result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            # Called even when batch_result is None; the helper decides whether
            # any work is needed.
            self.launch_batch_sample_if_needed(batch_result)
            self.last_batch = batch

            if envs.SGLANG_ENABLE_RUNTIME_MEM_LEAK_CHECK.get():
                self._check_runtime_mem_leak()

1007
1008
    def recv_requests(self) -> List[Req]:
        """Receive new requests on the leader rank and share them with the other ranks.

        * If a ``recv_skipper`` is configured and says not to poll this round,
          returns ``[]`` immediately.
        * Pipeline stage 0: the attention-TP leader (``attn_tp_rank == 0``)
          drains the tokenizer and RPC sockets without blocking.
        * Later pipeline stages: the attention-TP leader receives the list via
          ``point_to_point_pyobj``, with the peer rank computed from
          ``pp_rank - 1``.
        * With DP attention, generate/embedding requests are broadcast inside
          the attention-TP group and all other (control) requests across the
          whole TP group. Work requests come first in the result. Without DP
          attention, everything is broadcast across the TP group.

        Returns:
            The received requests, in the same list on every participating rank.
        """
        if self.recv_skipper is not None:
            last_forward_mode = (
                self.last_batch.forward_mode if self.last_batch is not None else None
            )
            if not self.recv_skipper.handle(last_forward_mode):
                return []

        if self.pp_rank == 0:
            if self.attn_tp_rank == 0:
                recv_reqs = []
                # Drain each socket until it has nothing pending (NOBLOCK
                # raises ZMQError once it is empty).
                for socket in (self.recv_from_tokenizer, self.recv_from_rpc):
                    while True:
                        try:
                            obj = socket.recv_pyobj(zmq.NOBLOCK)
                        except zmq.ZMQError:
                            break
                        recv_reqs.append(obj)
            else:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                cur_rank = self.pp_rank * self.tp_size + dp_offset
                recv_reqs = point_to_point_pyobj(
                    [],
                    cur_rank,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    cur_rank,
                )
            else:
                recv_reqs = None

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_req_types = (
                    TokenizedGenerateReqInput,
                    TokenizedEmbeddingReqInput,
                    BatchTokenizedGenerateReqInput,
                    BatchTokenizedEmbeddingReqInput,
                )
                # Single pass; relative order within each list is preserved.
                work_reqs = []
                control_reqs = []
                for req in recv_reqs:
                    if isinstance(req, work_req_types):
                        work_reqs.append(req)
                    else:
                        control_reqs.append(req)
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        if self.enable_trace:
            for req in recv_reqs:
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    trace_set_proc_propagate_context(req.rid, req.trace_context)
                    trace_slice_start("", req.rid, anonymous=True)

        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
1117
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any immediate reply.

        A health-check generate request that arrives while the scheduler is busy
        is not dispatched. It only increments ``return_health_check_ct``. Busy
        means a chunked request, a non-empty running batch, or pending offload
        tags. Every other request goes through ``self._request_dispatcher``. A
        non-``None`` output is sent back over the RPC socket if it is an
        ``RpcReqOutput`` and that socket exists; otherwise it goes to the
        tokenizer.
        """
        for recv_req in recv_reqs:
            # Health-check generation while work is in flight: count it instead
            # of dispatching it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None
                or not self.running_batch.is_empty()
                or len(self.offload_tags) > 0
            ):
                self.return_health_check_ct += 1
                continue

            output = self._request_dispatcher(recv_req)
            if output is None:
                continue

            if isinstance(output, RpcReqOutput):
                if self.recv_from_rpc is not None:
                    self.recv_from_rpc.send_pyobj(output)
            else:
                self.send_to_tokenizer.send_output(output, recv_req)
1135

1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
    def init_req_max_new_tokens(self, req):
        """Cap ``req.sampling_params.max_new_tokens`` by the remaining context budget.

        An unset (``None``) limit is treated as ``1 << 30``. The budget is
        ``self.max_req_len - len(req.origin_input_ids) - 1``, and the smaller of
        the two values is written back in place.
        """
        requested = req.sampling_params.max_new_tokens
        if requested is None:
            requested = 1 << 30
        budget = self.max_req_len - len(req.origin_input_ids) - 1
        req.sampling_params.max_new_tokens = min(requested, budget)

1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
    def _process_and_broadcast_mm_inputs(
        self,
        raw_mm_inputs: Optional[dict],
    ):
        """Materialize MultimodalInputs once on the entry rank and broadcast to others.

        Entry rank:
        - constructs MultimodalInputs.from_dict(raw_mm_inputs) once
        - broadcasts to other ranks in self.cpu_group (if world_size > 1)

        Non-entry ranks:
        - receive the object via broadcast (if world_size > 1)
        - otherwise (single-rank / no group) fall back to local from_dict

        Returns:
            MultimodalInputs | None
        """
        if raw_mm_inputs is None:
            return None

        group_world_size = 1
        try:
            if (
                torch.distributed.is_available()
                and torch.distributed.is_initialized()
                and self.cpu_group is not None
            ):
                group_world_size = torch.distributed.get_world_size(
                    group=self.cpu_group
                )
        except Exception as e:
            logger.warning(
                f"Failed to get world size in mm_inputs handling with {e}, fallback to 1."
            )

        # In case tp size > 1, all the Scheduler TP ranks runs the duplicated computing
        # process in CPU which occupies the main thread CPU cycle. This computing logic
        # merely needs to be run on TP0 and be broadcast to other TP ranks.
        # Since the Scheduler is single-threaded, any large CPU cost will impact
        # handling of other messages. For example, CPU hits 99.9% can significantly
        # increase the CUDA kernel launch time.
        if self.is_entry_rank:
            # Only the entry rank materializes once from dict.
            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
            # Broadcast to other TP ranks (use src=0 within the group).
            if group_world_size > 1:
                obj_list = [image_inputs]
                torch.distributed.broadcast_object_list(
                    obj_list, src=0, group=self.cpu_group
                )
                image_inputs = obj_list[0]
        else:
            # Non-entry ranks: receive if group size > 1; otherwise materialize locally.
            if group_world_size > 1:
                obj_list = [None]
                torch.distributed.broadcast_object_list(
                    obj_list, src=0, group=self.cpu_group
                )
                image_inputs = obj_list[0]
            else:
                image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)

        return image_inputs

1210
1211
1212
1213
    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
1214
        # Create a new request
1215
1216
1217
1218
1219
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
Rin Intachuen's avatar
Rin Intachuen committed
1220
1221
1222
1223
1224
1225
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

1226
1227
1228
1229
            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

1230
1231
1232
1233
1234
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
Lianmin Zheng's avatar
Lianmin Zheng committed
1235
1236
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
1237
                token_ids_logprob=recv_req.token_ids_logprob,
Lianmin Zheng's avatar
Lianmin Zheng committed
1238
                stream=recv_req.stream,
1239
                lora_id=recv_req.lora_id,
Rin Intachuen's avatar
Rin Intachuen committed
1240
                input_embeds=recv_req.input_embeds,
Lianmin Zheng's avatar
Lianmin Zheng committed
1241
                custom_logit_processor=recv_req.custom_logit_processor,
1242
                return_hidden_states=recv_req.return_hidden_states,
1243
                eos_token_ids=self.model_config.hf_eos_token_id,
1244
                bootstrap_host=recv_req.bootstrap_host,
1245
                bootstrap_port=recv_req.bootstrap_port,
1246
                bootstrap_room=recv_req.bootstrap_room,
1247
                disagg_mode=self.disaggregation_mode,
1248
                data_parallel_rank=recv_req.data_parallel_rank,
1249
                vocab_size=self.model_config.vocab_size,
1250
                priority=recv_req.priority,
1251
1252
1253
                metrics_collector=(
                    self.metrics_collector if self.enable_metrics else None
                ),
1254
                http_worker_ipc=recv_req.http_worker_ipc,
1255
1256
            )
            req.tokenizer = self.tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
1257

1258
1259
1260
            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
1261
                    error_msg = (
1262
1263
1264
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
1265
                    logger.error(error_msg)
1266
                    prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
1267
1268
1269
                    self.stream_output([req], req.return_logprob)
                    return

1270
1271
1272
1273
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
1274
                req.set_finish_with_abort(
1275
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
1276
                )
1277
                self.init_req_max_new_tokens(req)
1278
                self._add_request_to_queue(req)
1279
1280
                return
        else:
1281
1282
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
1283
            req = session.create_req(recv_req, self.tokenizer)
1284
            if isinstance(req.finished_reason, FINISH_ABORT):
1285
                self.init_req_max_new_tokens(req)
1286
                self._add_request_to_queue(req)
1287
                return
1288

1289
        # Handle multimodal inputs
Mick's avatar
Mick committed
1290
        if recv_req.mm_inputs is not None:
1291
1292
1293
            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.mm_inputs)

            # The following steps are already fast, execute locally on each rank.
1294
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
1295
            req.origin_input_ids = self.pad_input_ids_func(
1296
                req.origin_input_ids, image_inputs
1297
            )
1298
            req.extend_image_inputs(image_inputs)
1299

1300
            if len(req.origin_input_ids) >= self.max_req_input_len:
1301
1302
1303
1304
1305
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
1306
                )
1307
                self.init_req_max_new_tokens(req)
1308
                self._add_request_to_queue(req)
1309
1310
                return

1311
1312
1313
        # initialize before returning
        self.init_req_max_new_tokens(req)

1314
        # Validate prompt length
1315
1316
1317
1318
1319
1320
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
1321
            req.set_finish_with_abort(error_msg)
1322
            self._add_request_to_queue(req)
1323
            return
1324

1325
        # Copy more attributes
1326
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
1327
            # By default, only return the logprobs for output tokens
1328
1329
1330
1331
1332
1333
1334
            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
            # to skip input logprob computation entirely
            if req.is_prefill_only:
                req.logprob_start_len = len(req.origin_input_ids)
            else:
                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
                req.logprob_start_len = len(req.origin_input_ids) - 1
1335
1336
1337
        else:
            req.logprob_start_len = recv_req.logprob_start_len

1338
1339
1340
        if not req.is_prefill_only and req.logprob_start_len >= len(
            req.origin_input_ids
        ):
1341
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
1342
            req.logprob_start_len = len(req.origin_input_ids) - 1
1343
            req.set_finish_with_abort(error_msg)
1344
1345
1346
            self._add_request_to_queue(req)
            return

1347
1348
1349
1350
1351
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
1352
            or req.sampling_params.ebnf is not None
1353
            or req.sampling_params.structural_tag is not None
1354
        ):
1355
1356
1357
            if self.grammar_backend is None:
                error_msg = "Grammar-based generation (json_schema, regex, ebnf, structural_tag) is not supported when the server is launched with --grammar-backend none"
                req.set_finish_with_abort(error_msg)
1358
            else:
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
                if req.sampling_params.json_schema is not None:
                    key = ("json", req.sampling_params.json_schema)
                elif req.sampling_params.regex is not None:
                    key = ("regex", req.sampling_params.regex)
                elif req.sampling_params.ebnf is not None:
                    key = ("ebnf", req.sampling_params.ebnf)
                elif req.sampling_params.structural_tag:
                    key = ("structural_tag", req.sampling_params.structural_tag)

                value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
                req.grammar = value

                if not cache_hit:
                    req.grammar_key = key
                    add_to_grammar_queue = True
                else:
                    if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                        error_msg = f"Invalid grammar request with cache hit: {key=}"
                        req.set_finish_with_abort(error_msg)
1378
1379

        if add_to_grammar_queue:
1380
1381
            self.grammar_queue.append(req)
        else:
1382
1383
            self._add_request_to_queue(req)

1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
    def handle_batch_generate_request(
        self,
        recv_req: BatchTokenizedGenerateReqInput,
    ):
        """Fan out a batched tokenized generate request.

        Every request contained in ``recv_req`` is handed, in order, to
        ``handle_generate_request``.
        """
        num_reqs = len(recv_req)
        logger.debug(f"Processing batch generate request with {num_reqs} requests")
        for single_req in recv_req:
            self.handle_generate_request(single_req)

1395
1396
1397
    def _prefetch_kvcache(self, req: Req):
        """Kick off a storage-to-host KV-cache prefetch for ``req`` if allowed.

        Does nothing unless hierarchical-cache storage is enabled. The
        request's prefix match is refreshed first; a prefetch is only started
        when the matched ``last_node`` is already backed up.
        """
        if not self.enable_hicache_storage:
            return

        req.init_next_round_input(self.tree_cache)
        last_node = req.last_node
        if not last_node.backuped:
            # Without a backup, the allocated GPU memory must stay locked for
            # integrity, so no prefetch is started.
            return

        host_node = req.last_host_node
        last_hash = host_node.get_last_hash_value()
        already_matched = len(req.prefix_indices) + req.host_hit_length
        remaining_tokens = req.fill_ids[already_matched:]

        if self.tree_cache.hicache_storage_pass_prefix_keys:
            prefix_keys = last_node.get_prefix_hash_values(last_node.parent)
        else:
            prefix_keys = None

        self.tree_cache.prefetch_from_storage(
            req.rid,
            host_node,
            remaining_tokens,
            last_hash,
            prefix_keys,
        )

1418
1419
    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
        """Route ``req`` to the queue matching ``self.disaggregation_mode``.

        - NULL: default/validate the priority and enforce the waiting-queue
          limit (either may abort the request), attempt a KV-cache prefetch,
          then append to ``waiting_queue`` and stamp the entry time.
        - PREFILL: attempt a KV-cache prefetch and hand the request to the
          prefill bootstrap queue.
        - DECODE: hand the request to the decode prealloc queue; the entry time
          is stamped only for requests that are not being retracted.

        Raises:
            ValueError: if the disaggregation mode is none of the above.
        """
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.NULL:
            # Both helpers notify the tokenizer themselves when they abort the
            # incoming request; such a request must not be enqueued.
            if not self._set_or_validate_priority(req) or self._abort_on_queued_limit(
                req
            ):
                return
            self._prefetch_kvcache(req)
            self.waiting_queue.append(req)
            req.time_stats.wait_queue_entry_time = time.perf_counter()
            trace_slice_end(RequestStage.REQUEST_PROCESS, req.rid, auto_next_anon=True)
            return

        if mode == DisaggregationMode.PREFILL:
            self._prefetch_kvcache(req)
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
            return

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
            if not is_retracted:
                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
            return

        raise ValueError(f"Invalid {self.disaggregation_mode=}")
1440

1441
    def _set_or_validate_priority(self, req: Req) -> bool:
        """Fill in a default priority, or reject a priority the server ignores.

        With priority scheduling enabled, a request without a priority gets the
        least-preferred value: ``sys.maxsize`` when low values are scheduled
        first, ``-sys.maxsize - 1`` otherwise.

        With priority scheduling disabled, a request that carries a priority is
        aborted (an ``AbortReq`` is sent to the tokenizer) when
        ``abort_on_priority_when_disabled`` is set.

        Returns:
            False if the request was aborted, True otherwise.
        """
        if self.enable_priority_scheduling:
            if req.priority is None:
                req.priority = (
                    sys.maxsize
                    if self.schedule_low_priority_values_first
                    else -sys.maxsize - 1
                )
            return True

        # Priority scheduling is off from here on.
        if req.priority is None or not self.abort_on_priority_when_disabled:
            return True

        self.send_to_tokenizer.send_output(
            AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
                rid=req.rid,
            ),
            req,
        )
        return False
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483

    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
        """Enforce ``max_queued_requests`` on the waiting queue.

        If admitting ``recv_req`` would exceed the limit, exactly one request
        is aborted and an ``AbortReq`` is sent to the tokenizer:

        - By default the incoming ``recv_req`` is rejected.
        - With priority scheduling enabled and a non-empty waiting queue, the
          least-preferred queued request is evicted instead, but only when
          ``recv_req`` has strictly better priority.

        Returns:
            True if ``recv_req`` itself was aborted; False if the queue had room
            or an existing request was evicted in its place.
        """
        if (
            self.max_queued_requests is None
            or len(self.waiting_queue) + 1 <= self.max_queued_requests
        ):
            return False

        # Reject the incoming request by default.
        req_to_abort = recv_req
        message = "The request queue is full."
        # An empty waiting queue (reachable when max_queued_requests <= 0) has
        # no eviction candidate, and max() over it would raise ValueError.
        if self.enable_priority_scheduling and self.waiting_queue:
            # With priority scheduling, consider aborting an existing request
            # based on priority.
            # direction = 1  => smaller number = higher priority;
            # direction = -1 => larger number = higher priority.
            # Maximizing (direction * priority, wait_queue_entry_time) picks the
            # least-preferred request; on a priority tie, the later
            # wait_queue_entry_time (newer request) is evicted first.
            # Preempt only if the incoming request is strictly better.
            direction = 1 if self.schedule_low_priority_values_first else -1

            def eviction_key(item):
                _, queued_req = item
                return (
                    direction * queued_req.priority,
                    queued_req.time_stats.wait_queue_entry_time,
                )

            idx, candidate_req = max(enumerate(self.waiting_queue), key=eviction_key)
            if direction * recv_req.priority < direction * candidate_req.priority:
                self.waiting_queue.pop(idx)
                req_to_abort = candidate_req
                message = "The request is aborted by a higher priority request."

        self.send_to_tokenizer.send_output(
            AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": message,
                },
                rid=req_to_abort.rid,
            ),
            req_to_abort,
        )
        return req_to_abort.rid == recv_req.rid
1507
1508
1509

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs, if present, are processed/broadcast and the prompt
        is padded via ``pad_input_ids_func``. A prompt that is too long, either
        after multimodal expansion or as reported by ``validate_input_length``,
        is marked with ``set_finish_with_abort`` and then still passed to
        ``_add_request_to_queue``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
            priority=recv_req.priority,
            dimensions=recv_req.dimensions,
            http_worker_ipc=recv_req.http_worker_ipc,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = self._process_and_broadcast_mm_inputs(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Record the failure on the request, as handle_generate_request and
            # the multimodal branch above do, instead of discarding error_msg.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1556

1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
    def handle_batch_embedding_request(
        self,
        recv_req: BatchTokenizedEmbeddingReqInput,
    ):
        """Fan out a batched tokenized embedding request.

        Every request contained in ``recv_req`` is handed, in order, to
        ``handle_embedding_request``.
        """
        num_reqs = len(recv_req)
        logger.debug(
            f"Processing batch embedding request with {num_reqs} requests"
        )
        for single_req in recv_req:
            self.handle_embedding_request(single_req)

Hanming Lu's avatar
Hanming Lu committed
1570
1571
1572
1573
1574
1575
1576
    def _get_token_info(self):
        """Summarize KV-cache token usage.

        A token slot counts as used when it is neither free in the allocator
        nor evictable from the tree cache.

        Returns:
            ``(num_used, token_usage, available_size, evictable_size)``, where
            ``token_usage = num_used / max_total_num_tokens``.
        """
        free_tokens = self.token_to_kv_pool_allocator.available_size()
        reclaimable_tokens = self.tree_cache.evictable_size()
        capacity = self.max_total_num_tokens
        used_tokens = capacity - (free_tokens + reclaimable_tokens)
        return used_tokens, used_tokens / capacity, free_tokens, reclaimable_tokens

1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
    def _get_mamba_token_info(self):
        """Summarize usage of the full KV pool and the Mamba state pool.

        Evictable sizes come from the tree cache only when it is a
        ``MambaRadixCache``; otherwise they are taken as 0.

        Returns:
            ``(full_num_used, mamba_num_used, full_token_usage, mamba_usage,
            full_available_size, full_evictable_size, mamba_available_size,
            mamba_evictable_size)``.
        """
        kv_allocator = self.token_to_kv_pool_allocator
        mamba_pool = self.req_to_token_pool.mamba_pool
        cache = self.tree_cache
        has_radix = isinstance(cache, MambaRadixCache)

        full_free = kv_allocator.available_size()
        full_evictable = cache.full_evictable_size() if has_radix else 0
        mamba_free = mamba_pool.available_size()
        mamba_evictable = cache.mamba_evictable_size() if has_radix else 0

        full_used = kv_allocator.size - (full_free + full_evictable)
        mamba_used = mamba_pool.size - (mamba_free + mamba_evictable)

        return (
            full_used,
            mamba_used,
            full_used / kv_allocator.size,
            mamba_used / mamba_pool.size,
            full_free,
            full_evictable,
            mamba_free,
            mamba_evictable,
        )

Hanming Lu's avatar
Hanming Lu committed
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
    def _get_swa_token_info(self):
        """Summarize usage of the full-attention and sliding-window KV pools.

        Capacities are ``full_tokens_per_layer`` and ``swa_tokens_per_layer``;
        free sizes come from the allocator, evictable sizes from the tree cache.

        Returns:
            ``(full_num_used, swa_num_used, full_token_usage, swa_token_usage,
            full_available_size, full_evictable_size, swa_available_size,
            swa_evictable_size)``.
        """
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache

        full_free = allocator.full_available_size()
        full_evictable = cache.full_evictable_size()
        swa_free = allocator.swa_available_size()
        swa_evictable = cache.swa_evictable_size()

        full_capacity = self.full_tokens_per_layer
        swa_capacity = self.swa_tokens_per_layer
        full_used = full_capacity - (full_free + full_evictable)
        swa_used = swa_capacity - (swa_free + swa_evictable)

        return (
            full_used,
            swa_used,
            full_used / full_capacity,
            swa_used / swa_capacity,
            full_free,
            full_evictable,
            swa_free,
            swa_evictable,
        )

1630
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the next batch: a new prefill batch if one can be formed,
        otherwise the updated running decode batch.

        First, any in-flight chunked request is cached into the tree cache and
        its req-pool slot freed (it keeps its rid but gets a new slot later).
        The previous extend batch (``self.last_batch``) is then filtered and
        merged into ``self.running_batch``. When MLP sync is required, the
        result may be replaced by ``prepare_mlp_sync_batch``.

        Returns:
            The batch to run next, or ``None`` when there is nothing to run.
        """
        # Chunked requests that must not be merged into running_batch.
        excluded_chunked = set()

        chunked = self.chunked_req
        if chunked:
            # Take the chunked request out so only finished requests are
            # merged into running_batch.
            excluded_chunked.add(chunked)
            self.tree_cache.cache_unfinished_req(chunked, chunked=True)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            if self.tp_worker.model_runner.mambaish_config is None:
                self.req_to_token_pool.free(chunked.req_pool_idx)
            else:
                self.req_to_token_pool.free(
                    chunked.req_pool_idx, free_mamba_cache=False
                )

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism, the current microbatch may still
                # track a chunked_req whose last chunk has already run; discard it.
                excluded_chunked.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded_chunked))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Prefill-only batches skip the decoding step entirely.
            if not prev.is_empty() and not prev.is_prefill_only:
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)
        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # Under speculative decoding, prefill and decode batches cannot share
            # a DP attention group, so idle batches are prepared up front and
            # decode preparation is skipped when the group has prefill work.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Prefill takes precedence over decode.
            ret = new_batch
        elif self.running_batch.is_empty():
            ret = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = None if self.running_batch.is_empty() else self.running_batch

        # Handle DP attention
        if need_dp_attn_preparation:
            ret = self.prepare_mlp_sync_batch(ret)

        if ret:
            trace_attrs = {"bid": hex(id(ret)), "batch_size": ret.batch_size()}
            trace_event_batch("schedule", ret.reqs, attrs=trace_attrs)

        return ret
1698

1699
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted beside ``running_bs``.

        The budget is ``pp_max_micro_batch_size - running_bs``. With pipeline
        parallelism (``pp_size > 1``) it is additionally capped by the free
        slots in ``req_to_token_pool``.
        """
        budget = get_global_server_args().pp_max_micro_batch_size - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1705
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Form a new prefill (extend) batch from the waiting queue.

        Grammar-ready requests are moved in first. The method returns early
        when prefill is not possible. Otherwise the waiting queue is ordered
        by the scheduling policy and a ``PrefillAdder`` admits requests, with
        the in-flight chunked request going first. Admission stops when the
        request-slot budget is hit, the LoRA set cannot run, or the adder
        returns a non-CONTINUE result. Admitted requests are removed from the
        waiting queue, and any requests preempted by the adder are re-queued.

        Returns:
            A ``ScheduleBatch`` prepared for extend, possibly mixed with the
            running decode batch under mixed chunked prefill, or ``None`` when
            no request was admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if self.try_preemption:
            # Reset batch_is_full to try preemption with a prefill adder.
            self.running_batch.batch_is_full = False

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or not self.waiting_queue
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if (
            self.get_num_allocatable_reqs(running_bs) <= 0
            and not self.chunked_req
            and not self.try_preemption
        ):
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        if TEST_RETRACT and running_bs > TEST_RETRACT_NO_PREFILL_BS:
            # If we are testing retraction and the running batch size exceeds
            # TEST_RETRACT_NO_PREFILL_BS, we skip the prefill to keep the requests
            # in the waiting queue.
            return None

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
            self.priority_scheduling_preemption_threshold,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {running_req.lora_id for running_req in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | {admitted.lora_id for admitted in adder.can_run_list}
                | {req.lora_id}
            ):
                self.running_batch.batch_is_full = True
                break

            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check if the available size for the actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True

            if self.running_batch.batch_is_full:
                if not self.try_preemption:
                    break
                if not adder.preempt_to_schedule(req, self.server_args):
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(
                req,
                has_chunked_req=(self.chunked_req is not None),
                truncation_align_size=self.truncation_align_size,
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.add_latency(RequestStage.PREFILL_WAITING)

        # Build the membership set once. Constructing it inside the
        # comprehension rebuilt it for every waiting request, which made this
        # step O(len(waiting_queue) * len(can_run_list)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]
        if adder.preempt_list:
            for req in adder.preempt_list:
                self._add_request_to_queue(req)

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs, 0)

        for req in can_run_list:
            if req.time_stats.forward_entry_time == 0:
                # Avoid update chunked request many times
                req.time_stats.forward_entry_time = time.perf_counter()
                if self.enable_metrics:
                    self.metrics_collector.observe_queue_time(
                        req.time_stats.get_queueing_time(),
                    )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1892
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step (in place).

        Finished requests are filtered out.

        If ``check_decode_mem`` fails, or a ``TEST_RETRACT`` interval triggers,
        ``retract_decode`` is invoked. Requests it marks for abort are reported
        to the tokenizer, retracted requests are re-queued, and
        ``new_token_ratio`` takes the value it returns.

        Otherwise ``new_token_ratio`` decays by ``new_token_ratio_decay``, but
        not below ``min_new_token_ratio``.

        Returns:
            ``batch`` itself; it may be empty.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory (or a retraction test is due).
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0)

        if not must_retract:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            previous_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
                self.server_args
            )
            self.num_retracted_reqs = len(retracted_reqs)
            self.new_token_ratio = new_token_ratio

            # Notify the tokenizer about requests retract_decode chose to abort.
            for aborted in reqs_to_abort:
                reason: FINISH_ABORT = aborted.to_finish
                self.send_to_tokenizer.send_output(
                    AbortReq(abort_message=reason.message, rid=aborted.rid), aborted
                )

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {new_token_ratio:.4f}"
            )

            for retracted in retracted_reqs:
                self._add_request_to_queue(retracted, is_retracted=True)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1937

1938
1939
1940
1941
1942
1943
    # Override point; the base scheduler implementation is a no-op.
    def update_cache_from_scheduler(
        self, schedule_batch: ScheduleBatch, batch_result: GenerationBatchResult
    ):
        """Receive a scheduled batch and its generation result.

        Intentionally does nothing here; meant to be overridden.
        """

1944
1945
1946
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and return its result.

        For generation models there are three paths:

        * overlap scheduling (``enable_overlap``): the forward is launched inside
          ``self.forward_stream_ctx``. Future-map slots are reserved for the
          outputs, and ``batch.output_ids`` is set to the negated future
          indices rather than concrete token ids.
        * split prefill (``enable_pdmux`` and a split-prefill forward mode):
          dispatched to ``tp_worker.forward_batch_split_prefill``.
        * otherwise: ``model_worker.forward_batch_generation`` runs first, then
          the ``update_cache_from_scheduler`` hook is called.

        Embedding / reward models run ``tp_worker.forward_batch_embedding``.
        EXTEND batches get per-request prefill start/end timestamps.

        Returns:
            GenerationBatchResult for generation models, otherwise an
            EmbeddingBatchResult.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        # Optional artificial delay per forward (only when forward_sleep_time is set).
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Capture prefill start time for EXTEND mode
        if batch.forward_mode == ForwardMode.EXTEND:
            current_time = time.perf_counter()
            for req in batch.reqs:
                req.time_stats.prefill_start_time = current_time

        # Run forward
        if self.is_generation:
            # Speculative non-overlap paths consume the ScheduleBatch directly;
            # every other path uses the lowered ModelWorkerBatch.
            batch_or_worker_batch = batch

            if self.enable_overlap or self.spec_algorithm.is_none():
                # FIXME(lsyin): remove this if and finally unify the abstraction
                batch_or_worker_batch = batch.get_model_worker_batch()

            if self.enable_overlap:
                # FIXME: remove this assert
                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
                model_worker_batch = batch_or_worker_batch
                self.record_batch_in_overlap(model_worker_batch)

                # Sampling info will be modified during forward
                model_worker_batch.sampling_info = (
                    model_worker_batch.sampling_info.copy_for_forward()
                )

                # Reserve one future-map slot per sequence for this batch's outputs.
                bs = len(model_worker_batch.seq_lens)
                future_indices = self.future_map.alloc_future_indices(bs)

                with self.forward_stream_ctx:
                    # Order this forward after work already queued on the default stream.
                    self.forward_stream.wait_stream(self.default_stream)
                    self.future_map.resolve_future(model_worker_batch)
                    batch_result = self.model_worker.forward_batch_generation(
                        model_worker_batch
                    )
                    # FIXME(lsyin): maybe move this to forward_batch_generation
                    batch_result.copy_done = torch.get_device_module(
                        self.device
                    ).Event()
                    if batch_result.delay_sample_func is None:
                        self.future_map.store_to_map(future_indices, batch_result)
                        batch_result.copy_to_cpu()
                    else:
                        # Sampling is deferred; launch_batch_sample_if_needed stores
                        # the result under these indices later.
                        batch_result.future_indices = future_indices

                # FIXME(lsyin): move this assignment elsewhere
                # Negated future indices stand in for the not-yet-available token
                # ids (presumably resolved by future_map.resolve_future when the
                # next batch runs — confirm).
                future_indices_or_next_token_ids = -future_indices.indices

                if batch.is_v2_eagle:
                    # FIXME(lsyin): tmp code for eagle v2
                    # We only keep future indices for next draft input

                    batch.spec_info = batch_result.next_draft_input
                    batch.spec_info.future_indices = future_indices
                    # NOTE(review): this unconditionally marks every v2-eagle batch
                    # as greedy, overriding per-request sampling params (tagged
                    # "nhb" by its author). Confirm it is intended and not a
                    # leftover debugging hack.
                    batch.sampling_info.is_all_greedy = True
                    # batch.spec_info = EagleDraftInput(
                    #     future_indices=future_indices,
                    #     verify_done=batch_result.next_draft_input.verify_done,
                    #     # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
                    #     allocate_lens=batch_result.next_draft_input.allocate_lens,
                    # )

                    # The future value, usually for next batch preparation
                    # Current implementation strictly synchronizes the seq_lens
                    batch.seq_lens = batch_result.next_draft_input.new_seq_lens
            elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
                batch_result = self.tp_worker.forward_batch_split_prefill(batch)
                future_indices_or_next_token_ids = batch_result.next_token_ids
            else:
                batch_result = self.model_worker.forward_batch_generation(
                    batch_or_worker_batch
                )
                future_indices_or_next_token_ids = batch_result.next_token_ids
                self.update_cache_from_scheduler(batch, batch_result)

            # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
            #       which can probably be replaced by future_indices later [TODO(lsyin)].
            #       we shall still keep the original outputs, e.g. next_token_ids
            #       in the GenerationBatchOutput for processing after copy_done.
            batch.output_ids = future_indices_or_next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None

            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            batch_result.extend_input_len_per_req = extend_input_len_per_req
            batch_result.extend_logprob_start_len_per_req = (
                extend_logprob_start_len_per_req
            )
            ret = batch_result
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(embeddings=embeddings)

        # Capture prefill end time for EXTEND mode
        # NOTE(review): with overlap enabled the forward was only launched on
        # forward_stream above, so this presumably records launch time rather
        # than device completion — confirm if used for latency metrics.
        if batch.forward_mode == ForwardMode.EXTEND:
            current_time = time.perf_counter()
            for req in batch.reqs:
                req.time_stats.prefill_end_time = current_time

        return ret
Chayenne's avatar
Chayenne committed
2068

2069
2070
    def launch_batch_sample_if_needed(
        self, batch_result: GenerationBatchResult
    ) -> None:
        """Run a deferred sampling step for an overlap-scheduled batch, if any.

        ``run_batch`` leaves ``batch_result.delay_sample_func`` set, and
        ``batch_result.future_indices`` populated, when sampling was deferred.
        In that case this runs the deferred sampler inside the forward-stream
        context, stores the result into the future map under those indices, and
        calls ``copy_to_cpu()`` on it.

        Args:
            batch_result: The result returned by ``run_batch``; may be ``None``.

        Returns:
            None. The return annotation previously advertised a batch result,
            but nothing is returned; ``batch_result`` is updated in place.
        """
        # TODO(lsyin): make the delayed sample a default behavior after
        # unifying the forward_batch_generation interface (related to spec V2).
        if batch_result is None or batch_result.delay_sample_func is None:
            return

        with self.forward_stream_ctx:
            # Order the sampling after work already queued on the default stream.
            self.forward_stream.wait_stream(self.default_stream)
            _batch_result = batch_result.delay_sample_func()
            # The deferred sampler is expected to fill the same object in place.
            assert _batch_result is batch_result
            self.future_map.store_to_map(batch_result.future_indices, batch_result)
            batch_result.copy_to_cpu()
2083

2084
2085
2086
2087
2088
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Route a finished forward's result to the handler for its mode.

        Decode batches go to ``process_batch_result_decode`` and are traced as a
        decode-loop step. Extend batches go to ``process_batch_result_prefill``.
        For idle batches under the overlap scheduler, the pending ``copy_done``
        event (if any) is synchronized. In every case a health-check reply may
        then be emitted.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
            trace_slice_batch(RequestStage.DECODE_LOOP, batch.reqs)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif (
            mode.is_idle()
            and self.enable_overlap
            and result.copy_done is not None
        ):
            result.copy_done.synchronize()

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        """Send one owed health-check reply to the tokenizer, if any is pending.

        ``return_health_check_ct`` counts outstanding replies, and each call
        consumes at most one. Replying from the scheduling loop keeps health
        checks from being blocked by long-context prefills. This path does not,
        however, check the detokenizer manager's status.
        """
        if not self.return_health_check_ct:
            return
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_output(HealthCheckOutput())
2110

2111
2112
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Instance-bound wrapper around ``prepare_mlp_sync_batch_raw``.

        Supplies every configuration value the static helper needs from this
        scheduler and its ``server_args``.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
            offload_tags=self.offload_tags,
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
        offload_tags: set[str],
    ):
        """All-gather per-rank batch statistics and annotate the local batch.

        Each rank contributes a row of six int64 values: its token count,
        whether it can use CUDA graphs, its logprob token count, whether it is
        an extend batch, and two values from the two-batch-overlap preparer.
        These are gathered over ``tp_group`` into a ``(dp_size, attn_tp_size, 6)``
        tensor. If this rank has no batch but some peer has tokens, an idle
        batch is created via ``get_idle_batch`` so the rank still participates.

        Args:
            local_batch: This rank's batch, or ``None`` when it has no work.
            get_idle_batch: Zero-arg factory for an idle ScheduleBatch.
            require_mlp_tp_gather: When False, the batch keeps only its own
                token counts; when True it receives the gathered per-DP lists.
            offload_tags: When empty and overlap scheduling is disabled, the
                gather uses the device group; otherwise the CPU group.

        Returns:
            ``local_batch`` (possibly a newly created idle batch) with global
            sync fields set, or ``None`` when no rank has tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            if local_batch.return_logprob:
                num_tokens_for_logprob = sum(
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens,
                        local_batch.extend_lens,
                    )
                )
            else:
                # When return_logprob = False, only need last token per request
                num_tokens_for_logprob = local_batch.batch_size()

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()
        # Choose the collective's backend group: the device group only when nothing
        # is offloaded and overlap scheduling is off; otherwise gather on CPU
        # (presumably to avoid a blocking device collective — TODO confirm).
        if len(offload_tags) == 0 and disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            group = tp_group.cpu_group
            device = "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        # NOTE(review): the trailing 6 assumes prepare_all_gather returns exactly
        # 2 values (read back as slots 4:6 below); keep these in sync.
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
            device=device,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=group,
        )
        # Per-DP values are read from attention-TP index 0 of each DP rank.
        global_num_tokens = global_info[:, 0, 0].tolist()
        # CUDA graphs are usable only if every DP rank can use them.
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Some peer has work: run an idle batch so this rank joins the collectives.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
2233
2234
2235
2236
2237

    def get_idle_batch(self):
        """Build an empty ScheduleBatch and prepare it for an idle forward.

        Passed as the ``get_idle_batch`` factory to ``prepare_mlp_sync_batch_raw``,
        which uses it when this rank has no batch but a peer does.
        """
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

2247
2248
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Requests are examined in queue order. Each pending grammar future is
        polled with a short timeout, and polling stops at the first request
        that is still not ready. That request is treated as timed out once it
        has been polled more than ``GRAMMAR_TIMEOUT / 0.03`` times.

        With tensor parallelism, ranks agree on the MAX number of ready and
        timed-out requests. Lagging ranks block on the grammars their peers
        already resolved. Every rank then times out the same requests and moves
        the same queue prefix to the waiting queue.
        """
        poll_timeout = 0.03
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue

                req.grammar = req.grammar.result(timeout=poll_timeout)
                self._cache_resolved_grammar(req)
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Block on the grammars that peers already resolved but this rank has not.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self._cache_resolved_grammar(req)
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests sit right after the agreed ready prefix. Index from
        # num_ready_reqs_max (not the local count), so a lagging rank neither
        # cancels a grammar it just resolved above nor hands an unresolved future
        # to the waiting queue.
        for i in range(num_ready_reqs_max, num_ready_reqs_max + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)

        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        for req in self.grammar_queue[:num_ready_reqs]:
            self._add_request_to_queue(req)
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def _cache_resolved_grammar(self, req):
        """Cache req's freshly resolved grammar and abort req if the grammar is invalid."""
        self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
        if req.grammar is INVALID_GRAMMAR_OBJ:
            error_msg = f"Invalid grammar request: {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)

2314
2315
2316
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request by calling ``flush_cache`` and wrapping its outcome."""
        return FlushCacheReqOutput(success=self.flush_cache())
2317

2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
        """Clear the hierarchical-cache storage backend when hierarchical cache is enabled.

        Returns:
            ClearHiCacheReqOutput with ``success=False`` if hierarchical caching
            is disabled, ``success=True`` after the storage backend is cleared.
        """
        if not self.enable_hierarchical_cache:
            # Module logger (not the root ``logging`` module), matching the
            # success path so both messages follow the same logger config.
            logger.warning("Hierarchical cache is not enabled.")
            return ClearHiCacheReqOutput(success=False)

        self.tree_cache.clear_storage_backend()
        logger.info("Hierarchical cache cleared successfully!")
        return ClearHiCacheReqOutput(success=True)

cctry's avatar
cctry committed
2328
2329
    def _is_no_request(self):
        """Return True when the scheduler holds no queued, running or in-flight work.

        The following must all be empty:

        * the waiting queue;
        * the running, last and current batches;
        * the overlap result queue (when overlap is on);
        * the pipeline micro-batches (when pp_size > 1).

        Under PD disaggregation, the mode-specific queues must also be empty:
        bootstrap/inflight for PREFILL, prealloc/transfer for DECODE.
        """
        if len(self.waiting_queue) != 0 or not self.running_batch.is_empty():
            return False
        for other_batch in (self.last_batch, self.cur_batch):
            if other_batch is not None and not other_batch.is_empty():
                return False
        if self.enable_overlap and len(self.result_queue) != 0:
            return False
        if self.pp_size != 1 and not all(mb.is_empty() for mb in self.running_mbs):
            return False

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            return (
                len(self.disagg_prefill_bootstrap_queue.queue) == 0
                and len(self.disagg_prefill_inflight_queue) == 0
            )
        if mode == DisaggregationMode.DECODE:
            return (
                len(self.disagg_decode_prealloc_queue.queue) == 0
                and len(self.disagg_decode_transfer_queue.queue) == 0
            )
        return True

    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when ``_is_no_request()`` reports no pending work.
        It then:

        * drops the current and last batches;
        * resets the tree cache and the grammar backend (if any);
        * clears the request-to-token pool and the KV allocator;
        * clears the draft worker's cache pool (if any);
        * zeroes the generation and speculative-decoding counters;
        * calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        if not self._is_no_request():
            # Module logger (not the root ``logging`` module), consistent with the
            # success message below.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if self.draft_worker:
            self.draft_worker.clear_cache_pool()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_accepted_tokens = 0
        self.spec_num_forward_ct = 0
        self.spec_total_num_accepted_tokens = 0
        self.spec_total_num_forward_ct = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

2381
    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
        """Report this DP rank's load as request counts and an estimated token load.

        The token estimate has two parts:

        * KV usage that cannot be reclaimed by eviction. For hybrid SWA models
          this is the larger of the full and sliding-window pools.
        * The prompt lengths of requests still waiting. Under PD disaggregation
          this also includes the bootstrap queue (prefill) or the prealloc
          queue (decode).
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache
        if self.is_hybrid:
            full_in_use = (
                self.full_tokens_per_layer
                - allocator.full_available_size()
                - cache.full_evictable_size()
            )
            swa_in_use = (
                self.swa_tokens_per_layer
                - allocator.swa_available_size()
                - cache.swa_evictable_size()
            )
            num_tokens = max(full_in_use, swa_in_use)
        else:
            num_tokens = self.max_total_num_tokens - allocator.available_size()
            if self.is_hybrid_gdn:
                num_tokens -= cache.full_evictable_size()
            else:
                num_tokens -= cache.evictable_size()

        # Tokens in waiting queue, bootstrap queue, prealloc queue
        pending_reqs = list(self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            pending_reqs.extend(self.disagg_prefill_bootstrap_queue.queue)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            pending_reqs.extend(
                decode_req.req for decode_req in self.disagg_decode_prealloc_queue.queue
            )
        num_tokens += sum(len(r.origin_input_ids) for r in pending_reqs)
        num_waiting_reqs = len(pending_reqs)

        return GetLoadReqOutput(
            dp_rank=self.dp_rank,
            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
            num_waiting_reqs=num_waiting_reqs,
            num_tokens=num_tokens,
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
2431

2432
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the server args plus runtime statistics.

        The snapshot is a shallow copy of the global server args' attribute
        dict. It is extended with:

        * the last generation throughput;
        * memory usage (weights, KV cache, token capacity, graphs);
        * the average speculative accept length, when available;
        * per-step timings, when ``RECORD_STEP_TIME`` is set.

        Returns:
            GetInternalStateReqOutput wrapping the snapshot dict.
        """
        # Copy: ``vars()`` returns the object's live ``__dict__``. Writing into it
        # directly would turn these stats into attributes of the global server
        # args, and the ``pop`` below would delete its ``model_config`` attribute.
        ret = dict(vars(get_global_server_args()))
        model_runner = self.tp_worker.model_runner
        ret["last_gen_throughput"] = self.last_gen_throughput
        ret["memory_usage"] = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
            "graph": round(model_runner.graph_mem_usage, 2),
        }

        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
            ret["avg_spec_accept_length"] = (
                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

        # This field is not serializable.
        ret.pop("model_config", None)

        return GetInternalStateReqOutput(internal_state=ret)
2458
2459
2460
2461
2462

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Apply a whitelisted subset of server-arg updates at runtime.

        Only ``pp_max_micro_batch_size`` and the speculative accept thresholds
        may be changed. ``pp_max_micro_batch_size`` must lie within
        ``[1, max_running_requests // pp_size]``. Validation stops at the first
        rejected key, and in that case nothing is applied.

        On success, this method:

        * logs the current average speculative accept length (when available);
        * resets the speculative accept counters;
        * writes each new value onto the global server args.

        Returns:
            SetInternalStateReqOutput whose ``updated`` tells whether the update
            was applied, together with the resulting server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "pp_max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "pp_max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                avg_spec_accept_length = (
                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
            for k, v in server_args_dict.items():
                setattr(get_global_server_args(), k, v)
            logger.info(f"Global server args updated! {get_global_server_args()=}")

        return SetInternalStateReqOutput(
            # Report the real outcome so callers can tell a rejected update apart.
            updated=if_success,
            server_args=vars(get_global_server_args()),
        )

2497
2498
2499
2500
2501
2502
2503
2504
2505
2506
2507
2508
2509
2510
2511
2512
2513
2514
2515
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke ``self.<recv_req.method>(recv_req.parameters)`` as an RPC.

        Any exception raised by the call is logged and reported in the reply
        instead of propagating. A ``torch.distributed`` barrier is entered
        before replying.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(error is None, str(error) if error else "")

2516
2517
    def abort_request(self, recv_req: AbortReq):
        """Abort matching requests wherever this scheduler currently holds them.

        A request matches when ``recv_req.abort_all`` is set or its rid starts
        with ``recv_req.rid``. Covered locations are the waiting queue, the
        grammar queue, the PD-disaggregation queues, and the running/current
        batches. Each location uses a different abort method, described inline.
        """
        # NOTE(review): matching is by rid *prefix*, so aborting "abc" also
        # matches "abc1"; presumably intended for grouped rids — confirm.
        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                to_del.append(i)

        # Pop in descending index order so the remaining indices stay valid
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            if self.enable_hicache_storage:
                # to release prefetch events associated with the request
                self.tree_cache.release_aborted_request(req.rid)
            self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req)
            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.tree_cache.cache_finished_req(req)

            # For mamba radix cache
            # NOTE(review): in DECODE mode with a mamba pool index set,
            # cache_finished_req runs twice for the same req — confirm that is safe.
            if req.mamba_pool_idx is not None:
                self.tree_cache.cache_finished_req(req, is_insert=False)
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    # Only senders that implement abort() are notified.
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    # Only receivers that implement abort() are notified.
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        # Include cur_batch's requests too when it is a distinct batch.
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and (
                recv_req.abort_all or req.rid.startswith(recv_req.rid)
            ):
                # Abort method 3: set `to_finish`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_finish = FINISH_ABORT()
2599

2600
2601
2602
    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()

2603
2604
2605
2606
2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """Load a new LoRA adapter in place (from disk or Hugging Face).

        The work is delegated to ``self.tp_worker``; its result is returned
        unchanged.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload a LoRA adapter.

        The work is delegated to ``self.tp_worker``; its result is returned
        unchanged.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2619
2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Initialize the communication group between the seed and client instances.

        The TP worker performs the setup and returns a ``(success, message)``
        pair, which is wrapped into the response object.
        """
        worker = self.tp_worker
        outcome = worker.init_weights_send_group_for_remote_instance(recv_req)
        success, message = outcome
        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Send this (seed) instance's weights to the destination instance.

        The TP worker does the transfer and returns a ``(success, message)``
        pair, which is wrapped into the response object.
        """
        worker = self.tp_worker
        outcome = worker.send_weights_to_remote_instance(recv_req)
        success, message = outcome
        return SendWeightsToRemoteInstanceReqOutput(success, message)

2635
2636
2637
2638
2639
2640
2641
    def slow_down(self, recv_req: SlowDownReqInput):
        """Set (or clear) the artificial per-forward sleep time.

        A non-positive ``forward_sleep_time`` is normalized to ``None``
        (i.e. disabled); any other value, including ``None``, is stored as-is.
        """
        sleep_time = recv_req.forward_sleep_time
        non_positive = sleep_time is not None and sleep_time <= 0
        self.forward_sleep_time = None if non_positive else sleep_time
        return SlowDownReqOutput()

2642
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Drive the global expert-distribution recorder.

        ``recv_req.action`` selects ``start_record``, ``stop_record`` or
        ``dump_record`` on the recorder returned by
        ``get_global_expert_distribution_recorder()``.

        Raises:
            ValueError: if the action is not one of the three known values.
        """
        action = recv_req.action
        # Ordered (action, recorder method) pairs; compared with `==` exactly as
        # the previous if/elif chain did.
        dispatch = (
            (ExpertDistributionReqType.START_RECORD, "start_record"),
            (ExpertDistributionReqType.STOP_RECORD, "stop_record"),
            (ExpertDistributionReqType.DUMP_RECORD, "dump_record"),
        )
        for known_action, method_name in dispatch:
            if action == known_action:
                recorder = get_global_expert_distribution_recorder()
                getattr(recorder, method_name)()
                return ExpertDistributionReqOutput()
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
2653

2654
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new session keyed by ``recv_req.session_id``.

        Returns an ``OpenSessionReqOutput`` whose flag is False (with a logged
        warning) when the id is already registered or is ``None``, and True
        after a new ``Session`` has been stored.
        """
        session_id = recv_req.session_id

        # Reject duplicates first, then a missing id.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return OpenSessionReqOutput(session_id, True)
2668
2669
2670
2671
2672
2673
2674
2675
2676

    def close_session(self, recv_req: CloseSessionReqInput):
        """Drop the session named by ``recv_req.session_id``.

        An unknown id is not an error: a warning is logged and nothing changes.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2677
2678
    def get_print_prefix(self):
        """Return a rank tag string such as ``" DP0 TP1 PP2"``.

        The DP tag appears whenever ``attn_dp_rank`` is set. The TP and PP tags
        appear only when the corresponding size is greater than 1. Each tag is
        preceded by a single space, so the result is ``""`` when no tag applies.
        """
        tags = []
        if self.attn_dp_rank is not None:
            tags.append(f"DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            tags.append(f"TP{self.tp_rank}")
        if self.pp_size > 1:
            tags.append(f"PP{self.pp_rank}")
        return "".join(" " + tag for tag in tags)

2687
2688
    def current_scheduler_metrics_enabled(self):
        """Whether this scheduler should emit metrics.

        True on attention-TP rank 0. Otherwise the value of
        ``enable_metrics_for_all_schedulers`` is returned.
        """
        is_first_attn_tp_rank = self.attn_tp_rank == 0
        if is_first_attn_tp_rank:
            return is_first_attn_tp_rank
        return self.enable_metrics_for_all_schedulers
2689

2690
2691
2692
    def maybe_sleep_on_idle(self):
        """Delegate to the idle sleeper, if one is configured; otherwise no-op."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2693

2694
2695
2696
    def handle_freeze_gc(self, recv_req: FreezeGCReq):
        """Freeze this scheduler's GC, then forward the request to the detokenizer.

        Always returns ``None``.
        """
        freeze_gc("Scheduler")
        # The request object is passed as both arguments of send_output, as before.
        detokenizer_channel = self.send_to_detokenizer
        detokenizer_channel.send_output(recv_req, recv_req)
        return None

2700

2701
2702
2703
2704
2705
2706
2707
class IdleSleeper:
    """
    In setups with long periods of inactivity it is desirable to reduce system
    power consumption while sglang does nothing. This yields power savings and
    also more CPU thermal headroom when a request eventually arrives. This
    matters when multiple GPUs are connected, because each GPU would otherwise
    pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.

    While idle, ``torch.cuda.empty_cache()`` is also called at most once per
    ``SGLANG_EMPTY_CACHE_INTERVAL`` seconds. A value <= 0 disables this.
    """

    def __init__(self, sockets, poll_timeout_ms: int = 1000):
        """
        Args:
            sockets: zmq sockets to watch for incoming (POLLIN) data.
            poll_timeout_ms: upper bound, in milliseconds, on how long a single
                ``maybe_sleep`` call blocks when no socket becomes readable.
                Defaults to 1000, the previously hard-coded value.
        """
        self.poller = zmq.Poller()
        # Monotonic clock: the empty-cache interval is an elapsed-time
        # measurement and must not be skewed by wall-clock adjustments.
        self.last_empty_time = time.monotonic()
        self.poll_timeout_ms = poll_timeout_ms

        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

        self.empty_cache_interval = envs.SGLANG_EMPTY_CACHE_INTERVAL.get()

    def maybe_sleep(self):
        """Block until a registered socket is readable or the poll timeout expires.

        Afterwards, release cached CUDA memory if the empty-cache interval is
        enabled and has elapsed since the last release.
        """
        self.poller.poll(self.poll_timeout_ms)
        now = time.monotonic()
        if (
            self.empty_cache_interval > 0
            and now - self.last_empty_time > self.empty_cache_interval
        ):
            self.last_empty_time = now
            torch.cuda.empty_cache()
2729

2730

2731
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has an ``rid`` starting with ``"HEALTH_CHECK"``.

    Objects without an ``rid`` attribute, or whose ``rid`` is ``None``, yield False.
    """
    rid = getattr(recv_req, "rid", None)
    if rid is None:
        return False
    return rid.startswith("HEALTH_CHECK")
2734

2735
2736

def is_work_request(recv_req):
    """Return True if ``recv_req`` is a tokenized generate/embedding request.

    Both single and batched variants count.
    """
    work_request_types = (
        TokenizedGenerateReqInput,
        TokenizedEmbeddingReqInput,
        BatchTokenizedGenerateReqInput,
        BatchTokenizedEmbeddingReqInput,
    )
    return isinstance(recv_req, work_request_types)
2746
2747


2748
2749
2750
2751
2752
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Set up this process and run one ``Scheduler`` event loop.

    Steps, in order:
      1. Resolve ``dp_rank``, falling back to the ``SGLANG_DP_RANK`` env var.
         Build a rank prefix such as ``" DP0 TP1 EP1 PP0"``.
      2. Configure the process:
         - set the process title and enable ``faulthandler``;
         - call ``kill_itself_when_parent_died()`` and remember the parent process;
         - configure the logger;
         - optionally apply CPU affinity, NUMA binding and tracing.
      3. Construct a ``Scheduler`` and send a ``{"status": "ready", ...}`` dict
         through ``pipe_writer``. Then run the event loop selected by the
         scheduler's disaggregation mode, its pdmux/overlap flags, and
         ``server_args.pp_size``.

    Any exception raised during step 3 is logged with its traceback, and
    SIGQUIT is sent to the parent process; the exception is not re-raised.
    """
    # [For Router] without an explicit dp_rank, take it from SGLANG_DP_RANK if set.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Rank tags shared by the process title and the logger prefix.
    rank_tags = []
    if dp_rank is not None:
        rank_tags.append(f"DP{dp_rank}")
    if server_args.tp_size > 1:
        rank_tags.append(f"TP{tp_rank}")
    if server_args.ep_size > 1:
        rank_tags.append(f"EP{moe_ep_rank}")
    if server_args.pp_size > 1:
        rank_tags.append(f"PP{pp_rank}")
    prefix = "".join(f" {tag}" for tag in rank_tags)

    # Process-level configuration.
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Logging.
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Optional CPU affinity and NUMA placement for this GPU's process.
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(
            server_args.pp_size, server_args.tp_size, server_args.nnodes, gpu_id
        )
    numa_node = server_args.numa_node
    if numa_node is not None:
        numa_bind_to_node(numa_node[gpu_id])

    # Optional tracing; the thread label reflects the disaggregation role.
    if server_args.enable_trace:
        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "prefill":
            thread_label = "Prefill Scheduler"
        elif server_args.disaggregation_mode == "decode":
            thread_label = "Decode Scheduler"
        else:
            thread_label = "Scheduler"
        trace_set_thread_info(thread_label, tp_rank, dp_rank)

    # Build the scheduler, report readiness, then run the matching event loop.
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
        )
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
            "max_req_input_len": scheduler.max_req_input_len,
        }
        pipe_writer.send(ready_message)

        mode = scheduler.disaggregation_mode
        event_loop = None
        if mode == DisaggregationMode.NULL:
            if scheduler.enable_pdmux:
                event_loop = scheduler.event_loop_pdmux
            elif server_args.pp_size > 1:
                event_loop = scheduler.event_loop_pp
            elif scheduler.enable_overlap:
                event_loop = scheduler.event_loop_overlap
            else:
                event_loop = scheduler.event_loop_normal
        elif mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                event_loop = scheduler.event_loop_overlap_disagg_prefill
            elif server_args.pp_size > 1:
                event_loop = scheduler.event_loop_pp_disagg_prefill
            else:
                event_loop = scheduler.event_loop_normal_disagg_prefill
        elif mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                event_loop = scheduler.event_loop_overlap_disagg_decode
            else:
                event_loop = scheduler.event_loop_normal_disagg_decode

        # Any other mode runs no event loop, exactly as before.
        if event_loop is not None:
            event_loop()
    except Exception:
        traceback_text = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback_text}")
        parent_process.send_signal(signal.SIGQUIT)