scheduler.py 107 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
53
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
54
    ReqToMetadataIdxAllocator,
55
    TransferBackend,
56
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
57
)
58
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
59
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
from sglang.srt.layers.moe import initialize_moe_config
68
69
from sglang.srt.managers.io_struct import (
    AbortReq,
70
71
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
72
    CloseSessionReqInput,
73
    ExpertDistributionReq,
74
    ExpertDistributionReqOutput,
75
76
    FlushCacheReqInput,
    FlushCacheReqOutput,
77
    FreezeGCReq,
78
79
    GetInternalStateReq,
    GetInternalStateReqOutput,
80
    GetWeightsByNameReqInput,
81
    HealthCheckOutput,
82
    InitWeightsUpdateGroupReqInput,
83
84
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
85
86
    OpenSessionReqInput,
    OpenSessionReqOutput,
87
    ProfileReq,
88
89
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
90
91
    RpcReqInput,
    RpcReqOutput,
92
93
    SetInternalStateReq,
    SetInternalStateReqOutput,
94
95
    SlowDownReqInput,
    SlowDownReqOutput,
96
97
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
98
99
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
100
    UpdateWeightFromDiskReqInput,
101
    UpdateWeightsFromDistributedReqInput,
102
    UpdateWeightsFromTensorReqInput,
103
)
104
from sglang.srt.managers.mm_utils import init_embedding_cache
105
106
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
107
    MultimodalInputs,
108
109
    Req,
    ScheduleBatch,
110
    global_server_args_dict,
111
)
112
113
114
115
116
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
117
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
118
119
120
121
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
122
123
124
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
125
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
126
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
127
128
129
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
130
from sglang.srt.managers.session_controller import Session
131
from sglang.srt.managers.tp_worker import TpModelWorker
132
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
133
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
tarinkk's avatar
tarinkk committed
134
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
135
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
136
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
137
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
138
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Lianmin Zheng's avatar
Lianmin Zheng committed
139
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
140
from sglang.srt.reasoning_parser import ReasoningParser
141
from sglang.srt.server_args import PortArgs, ServerArgs
142
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
143
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
144
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
145
from sglang.srt.utils import (
146
    DynamicGradMode,
147
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
148
    configure_gc_logger,
149
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
150
    disable_request_logging,
151
    freeze_gc,
152
    get_available_gpu_memory,
153
    get_bool_env_var,
154
    get_zmq_socket,
155
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
156
    kill_itself_when_parent_died,
157
    point_to_point_pyobj,
158
    pyspy_dump_schedulers,
159
160
    require_mlp_sync,
    require_mlp_tp_gather,
161
    set_gpu_proc_affinity,
162
163
164
    set_random_seed,
    suppress_other_loggers,
)
165
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
166
167
168

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes (enabled via SGLANG_TEST_RETRACT).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Grammar timeout read from SGLANG_GRAMMAR_TIMEOUT; defaults to 300.
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

# Evaluated once at import time.
_is_cpu = is_cpu()


@dataclass
class GenerationBatchResult:
    """Result container for running a generation batch.

    Holds the logits output, sampled next-token ids and per-request extend
    bookkeeping produced for one batch.
    """

    # Logits processor output; may be None.
    logits_output: Optional[LogitsProcessorOutput]
    # Presumably the hidden states forwarded between pipeline-parallel stages
    # when this rank is not the last stage — TODO confirm against callers.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled token ids; may be None.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request positions from which logprobs start being computed.
    extend_logprob_start_len_per_req: List[int]
    # Batch identifier.
    bid: int
    # Whether this batch could be run with a captured CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result container for running an embedding batch."""

    # Embeddings produced for the batch.
    embeddings: torch.Tensor
    # Batch identifier.
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
195
196
197
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
198
199
200
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
201
202
203
204
205
206
207
208
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        dp_balance_meta: Optional[DPBalanceMeta] = None,
    ):
        """Build the scheduler for one (tp, pp, dp) rank.

        Copies configuration out of ``server_args`` and opens the ZMQ channels
        on the rank that talks to the tokenizer side (pp_rank 0 and
        attn_tp_rank 0). It then launches the tensor-parallel worker, plus an
        EAGLE draft worker when the speculative algorithm is EAGLE. Finally it
        initializes the memory pool, prefix cache, scheduling policy, metrics
        and disaggregation state, and registers a handler for every incoming
        request type.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names for the ZMQ sockets and the NCCL
                port handed to the workers.
            gpu_id: GPU index passed to the workers and used for memory
                reporting.
            tp_rank: Tensor-parallel rank.
            moe_ep_rank: MoE expert-parallel rank, forwarded to the workers.
            pp_rank: Pipeline-parallel rank.
            dp_rank: Data-parallel rank, or None.
            dp_balance_meta: DP load-balancing metadata; required when DP
                attention is enabled with the "minimum_tokens" load-balance
                method.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            # Other ranks do not own sockets: receivers are None and senders
            # are stubs whose send_pyobj discards the payload.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget evenly across pipeline
            # stages, never below 1.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill. A None or non-positive size (-1) disables it.
        # The None check mirrors init_memory_pool_and_cache and avoids a
        # TypeError from comparing None with 0.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # The new-token ratio starts at init_new_token_ratio and decays
        # linearly down to min_new_token_ratio over
        # default_new_token_ratio_decay_steps steps. Both values are capped
        # at 1.0.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread.
        # NOTE(review): the watchdog starts before the remaining init steps
        # below complete. Confirm that watchdog_thread only reads attributes
        # that are already set at this point.
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profier()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (FreezeGCReq, self.handle_freeze_gc),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

        self.balance_meta = dp_balance_meta
        if (
            server_args.enable_dp_attention
            and server_args.load_balance_method == "minimum_tokens"
        ):
            assert dp_balance_meta is not None

        self.recv_dp_balance_id_this_term = []

    def init_tokenizer(self):
        """Load the tokenizer and, for multimodal models, the processor.

        Sets ``self.is_generation`` from the model config, then always sets
        both ``self.tokenizer`` and ``self.processor``:

        * ``server_args.skip_tokenizer_init``: both are None.
        * multimodal model: the processor is loaded and the tokenizer is
          taken from it.
        * text-only model: the tokenizer is loaded and the processor is None.
        """
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        if self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only models have no processor. Set it explicitly so the
            # attribute exists on every path, not only the skip and multimodal
            # paths.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the worker's memory pools and build the prefix cache.

        Sets ``self.req_to_token_pool`` and ``self.token_to_kv_pool_allocator``
        from the TP worker. It then picks ``self.tree_cache``, using the first
        matching rule:

        1. Chunked prefill configured and radix cache disabled: a chunk cache
           (SWA variant for hybrid models).
        2. ``SGLANG_EXPERIMENTAL_CPP_RADIX_TREE=1``: the C++ radix cache.
        3. Hierarchical cache enabled: ``HiRadixCache``, whose layer-done
           counter is registered with the TP worker.
        4. Hybrid model: ``SWARadixCache`` (disaggregation not supported).
        5. LoRA enabled: ``LoRARadixCache`` (FCFS policy only).
        6. Otherwise: ``RadixCache``.

        Finally it sets ``self.decode_mem_cache_buf_multiplier`` and sizes the
        multimodal embedding cache.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            ChunkCacheClass = SWAChunkCache if self.is_hybrid else ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
            # lazy import to avoid JIT overhead
            from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

            self.tree_cache = RadixCacheCpp(
                disable=False,
                use_hicache=self.enable_hierarchical_cache,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                # With DP attention the cache is synchronized over the
                # attention-TP group instead of the full TP group.
                tp_cache_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
                hicache_io_backend=server_args.hicache_io_backend,
                hicache_mem_layout=server_args.hicache_mem_layout,
                hicache_storage_backend=server_args.hicache_storage_backend,
                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
            )
        elif self.is_hybrid:
            assert (
                self.server_args.disaggregation_mode == "null"
            ), "Hybrid mode does not support disaggregation yet"
            self.tree_cache = SWARadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                sliding_window_size=self.sliding_window_size,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        elif self.enable_lora:
            # The hierarchical-cache branch above is checked first, so this
            # assert always holds here. It guards against reordering.
            assert (
                not self.enable_hierarchical_cache
            ), "LoRA radix cache doesn't support hierarchical cache"
            assert (
                self.schedule_policy == "fcfs"
            ), "LoRA radix cache only supports FCFS policy"
            self.tree_cache = LoRARadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        # 1 without speculative decoding. Otherwise the number of draft
        # tokens plus eagle_topk * num_steps. Presumably this is the
        # per-request token headroom reserved for decode; confirm at the
        # call sites.
        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

        # SGLANG_VLM_CACHE_SIZE_MB (default 100) is converted to bytes.
        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

    def init_disaggregation(self):
        """Set up prefill/decode (PD) disaggregation state for this scheduler.

        Always creates ``self.transfer_backend``. Then, depending on
        ``self.disaggregation_mode``:

        * DECODE: sizes the metadata buffers from ``req_to_token_pool.size`` and
          builds the KV transfer-polling queue plus the pre-allocation queue.
        * PREFILL: sizes the metadata buffers from ``max_running_requests`` and
          builds the bootstrap queue and an empty in-flight list.
        * any other mode: nothing else is initialized.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def setup_metadata(buffer_size):
            # Shared by both roles: the index allocator and the metadata buffers
            # are sized identically and live on the scheduler.
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            setup_metadata(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    self.draft_worker.model_runner.token_to_kv_pool
                    if self.draft_worker is not None
                    else None
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )
            return

        if mode != DisaggregationMode.PREFILL:
            # Disaggregation disabled (or unknown mode): only the backend is set.
            return

        # *2 for the headroom.
        setup_metadata(self.max_running_requests * 2)

        self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
            token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
            draft_token_to_kv_pool=(
                self.draft_worker.model_runner.token_to_kv_pool
                if self.draft_worker is not None
                else None
            ),
            req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
            metadata_buffers=self.disagg_metadata_buffers,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            gpu_id=self.gpu_id,
            bootstrap_port=self.server_args.disaggregation_bootstrap_port,
            gloo_group=self.attn_tp_cpu_group,
            max_total_num_tokens=self.max_total_num_tokens,
            decode_tp_size=self.server_args.disaggregation_decode_tp,
            decode_dp_size=self.server_args.disaggregation_decode_dp,
            scheduler=self,
            pp_rank=self.pp_rank,
            pp_size=self.pp_size,
            transfer_backend=self.transfer_backend,
        )
        # The prefill requests that are in the middle of kv sending
        self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
774

775
776
777
778
    def init_moe_config(self):
        """Call ``initialize_moe_config`` only if the HF config defines ``num_experts_per_tok``.

        Presumably this attribute marks a mixture-of-experts model; models
        without it skip MoE initialization entirely.
        """
        hf_config = self.model_config.hf_config
        if not hasattr(hf_config, "num_experts_per_tok"):
            return
        initialize_moe_config(self.server_args)

779
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Runs forever. Each iteration:

        * ingests newly received requests;
        * picks the next batch;
        * runs it and processes its result synchronously.

        When no batch is available, the idle self-check runs instead.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()
            else:
                batch_output = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_output)

            self.last_batch = next_batch
797

798
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the newly selected batch first. Only afterwards
        does it pop and process the result of the batch launched in the previous
        iteration. Launched batches (copied) and their results wait in
        ``self.result_queue`` until then.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            launched = self.get_next_batch_to_run()
            self.cur_batch = launched

            if launched:
                launched.launch_done = threading.Event()
                launched_result = self.run_batch(launched)
                self.result_queue.append((launched.copy(), launched_result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    bootstrap_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(
                        bootstrap_batch, None, launched.launch_done
                    )

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                if launched:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                    # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
                    done_event = launched.launch_done
                else:
                    prev_batch.next_batch_sampling_info = None
                    done_event = None
                self.process_batch_result(prev_batch, prev_result, done_event)
            elif launched is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = launched

841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The scheduler keeps ``pp_size`` microbatches in flight. Each iteration
        of the inner loop, for microbatch ``mb_id``:

        1. Schedules and runs ``mbs[mb_id]``.
        2. On the last PP rank, sends its sampled outputs back through the
           pipeline.
        3. Receives the outputs for the *next* microbatch and post-processes
           them.
        4. On non-last ranks, forwards the outputs, the received requests and
           the hidden-state proxy tensors to the next stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # `can_run_cuda_graph` of each microbatch, recorded when that microbatch
        # ran on this rank. The result post-processed below belongs to
        # `next_mb_id`, not to the batch that was just run, so (like `bids`) the
        # flag must be tracked per microbatch rather than read off `result`.
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs: Optional[PPProxyTensors] = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        # Use the flag of the microbatch being processed, not
                        # the one that happened to run most recently.
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()
973

974
975
    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.

        Only ranks with ``attn_tp_rank == 0`` read new input:

        * on the first pipeline stage, they drain the tokenizer socket and then
          the RPC socket without blocking;
        * on later stages, they receive the requests forwarded by the previous
          stage through ``point_to_point_pyobj``.

        After the optional ``input_blocker`` hook, the requests are broadcast:

        * with DP attention, work requests go over the attention-TP group and
          control requests over the full TP group;
        * otherwise, everything goes over the TP group.

        Returns:
            The received requests, or ``[]`` when ``recv_skipper`` decides not to
            receive on this step.
        """
        if self.recv_skipper is not None:
            prev_forward_mode = (
                None if self.last_batch is None else self.last_batch.forward_mode
            )
            if not self.recv_skipper.handle(prev_forward_mode):
                return []

        if self.attn_tp_rank != 0:
            # Non-leader ranks receive their copy through the broadcasts below.
            recv_reqs = None
        elif self.pp_rank == 0:
            recv_reqs = []
            for source in (self.recv_from_tokenizer, self.recv_from_rpc):
                # Drain everything currently queued on this socket.
                while True:
                    try:
                        incoming = source.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(incoming)
        else:
            dp_offset = self.attn_dp_rank * self.attn_tp_size
            recv_reqs = point_to_point_pyobj(
                [],
                self.pp_rank * self.tp_size + dp_offset,
                self.world_group.device_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                self.pp_rank * self.tp_size + dp_offset,
            )

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(
                    recv_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return recv_reqs

        work_types = (
            TokenizedGenerateReqInput,
            TokenizedEmbeddingReqInput,
            BatchTokenizedGenerateReqInput,
            BatchTokenizedEmbeddingReqInput,
        )
        if self.attn_tp_rank == 0:
            work_reqs = [r for r in recv_reqs if isinstance(r, work_types)]
            control_reqs = [r for r in recv_reqs if not isinstance(r, work_types)]
        else:
            work_reqs = control_reqs = None

        # Work requests stay inside the attention-TP group; control requests
        # must reach every TP rank.
        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
1075
    def process_input_requests(self, recv_reqs: List):
        """Route each received request to its handler and forward any reply.

        * Health-check generate requests are only counted
          (``return_health_check_ct``) while a chunked request, running batch or
          offload tag is present.
        * Work requests are rejected with a SERVICE_UNAVAILABLE ``AbortReq``
          when the waiting queue is full.
        * Any other request goes through ``self._request_dispatcher``.
          ``RpcReqOutput`` replies go to the RPC socket (if it exists); all
          other non-``None`` replies go to the tokenizer.
        """
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None
                or not self.running_batch.is_empty()
                or len(self.offload_tags) > 0
            ):
                self.return_health_check_ct += 1
                continue

            # Back-pressure: reject a work request once the waiting queue is full.
            if (
                is_work_request(recv_req)
                and len(self.waiting_queue) + 1 > self.max_queued_requests
            ):
                rejection = AbortReq(
                    recv_req.rid,
                    finished_reason={
                        "type": "abort",
                        "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                        "message": "The request queue is full.",
                    },
                )
                self.send_to_tokenizer.send_pyobj(rejection)
                continue

            reply = self._request_dispatcher(recv_req)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
1106
1107
1108
1109
1110

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request, validate it and enqueue it.

        A request that fails validation is marked with
        ``set_finish_with_abort`` and still passed to ``_add_request_to_queue``.
        Validation covers:

        * an unknown session id;
        * a multimodal prompt that is too long after expansion;
        * the input length;
        * an out-of-range ``logprob_start_len``.

        The exception is a disaggregated request without a bootstrap room: it is
        aborted with ``prepare_abort`` and streamed out immediately.

        A request with a constrained-decoding spec (json_schema / regex / ebnf /
        structural_tag) whose grammar is not already cached goes to
        ``self.grammar_queue`` instead of the regular queue.

        Args:
            recv_req: The tokenized generate request to admit.
        """
        if (
            self.server_args.enable_dp_attention
            and self.server_args.load_balance_method == "minimum_tokens"
        ):
            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_id=recv_req.lora_id,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Clamp max_new_tokens so prompt + output fits within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The outer condition guarantees structural_tag is not None
                # here. Matching it with `is not None` (not truthiness) keeps
                # `key` bound even for a falsy tag such as "", which previously
                # raised UnboundLocalError inside the scheduler loop.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
    def handle_batch_generate_request(
        self,
        recv_req: BatchTokenizedGenerateReqInput,
    ):
        """Handle optimized batch generate request.

        Each tokenized request in the batch is passed, in order, to
        ``handle_generate_request``.

        Args:
            recv_req: The batch of tokenized generate requests. It must support
                ``len()`` and iteration.
        """
        # Lazy %-args: the message is only formatted when DEBUG logging is on,
        # instead of building an f-string on every batch.
        logger.debug(
            "Processing batch generate request with %d requests", len(recv_req)
        )

        # Process each request in the batch
        for tokenized_req in recv_req:
            self.handle_generate_request(tokenized_req)

1283
    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and enqueue ``req`` by disaggregation mode.

        * decode: the pre-allocation queue;
        * prefill: an optional hicache prefetch, then the bootstrap queue;
        * otherwise: an optional hicache prefetch, then ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        # Both remaining paths try a storage prefetch before enqueueing.
        self._prefetch_kvcache(req)
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
        else:
            self.waiting_queue.append(req)

1296
1297
1298
    def _prefetch_kvcache(self, req: Req):
        """Ask the tree cache to prefetch ``req``'s unmatched input tokens from storage.

        Does nothing unless hicache storage is enabled. Even then, it only acts
        when the request's last matched node has been backed up.
        """
        if not self.enable_hicache_storage:
            return

        req.init_next_round_input(self.tree_cache)
        if not req.last_node.backuped:
            # only to initiate the prefetch if the last node is backuped
            # otherwise, the allocated GPU memory must be locked for integrity
            return

        host_node = req.last_host_node
        last_hash = host_node.get_last_hash_value()
        # Tokens already covered by the device prefix plus the host-side hit.
        already_matched = len(req.prefix_indices) + req.host_hit_length
        remaining_tokens = req.fill_ids[already_matched:]
        self.tree_cache.prefetch_from_storage(
            req.rid, host_node, remaining_tokens, last_hash
        )

1309
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Enqueue several requests into the queue that matches the disaggregation mode.

        Unlike ``_add_request_to_queue``, this neither stamps the queue time nor
        triggers a KV-cache prefetch.

        Args:
            reqs: Requests to enqueue, in order.
            is_retracted: Forwarded to the decode pre-allocation queue only.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return
        if mode == DisaggregationMode.DECODE:
            # If this is a decode server, we put the request to the decode pending prealloc queue
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return
        self.waiting_queue.extend(reqs)
1319
1320
1321

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Multimodal inputs are first expanded: each placeholder becomes the padded
        token span produced by ``pad_input_ids_func``. Every request ends up in
        ``_add_request_to_queue``, including ones that are too long. Too-long
        multimodal prompts are marked with ``set_finish_with_abort`` first; prompts
        rejected by ``validate_input_length`` are enqueued as-is (presumably already
        marked by that helper - confirm). Only accepted requests get
        ``logprob_start_len`` set to the last prompt position.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        raw_mm_inputs = recv_req.image_inputs
        if raw_mm_inputs is not None:
            mm_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
            # Replace each image token with enough dummy tokens to receive the image embeddings.
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, mm_inputs
            )
            req.extend_image_inputs(mm_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        length_error = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if not length_error:
            req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1365

1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
    def handle_batch_embedding_request(
        self,
        recv_req: BatchTokenizedEmbeddingReqInput,
    ):
        """Split a batched embedding request into individual embedding requests.

        Each element of ``recv_req`` is passed, in order, to
        ``handle_embedding_request``.
        """
        logger.debug(
            f"Processing batch embedding request with {len(recv_req)} requests"
        )

        handle_one = self.handle_embedding_request
        for single_req in recv_req:
            handle_one(single_req)

1379
1380
1381
1382
1383
    def self_check_during_idle(self):
        """Run housekeeping while the scheduler has no work.

        In order, this method:
        1. Runs ``check_memory``, which raises on a KV-cache or request-slot leak
           and may also log idle-time metrics.
        2. Runs ``check_tree_cache``.
        3. Resets ``new_token_ratio`` to ``init_new_token_ratio``, undoing any decay
           that happened while decoding.
        4. Calls ``maybe_sleep_on_idle``. That method is not shown here; by its name
           it presumably may put the scheduler to sleep - confirm.
        """
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
1384

Lianmin Zheng's avatar
Lianmin Zheng committed
1385
    def check_memory(self):
        """Verify that no KV-cache tokens or request slots are leaked, then publish KV events.

        Token check:
        - Hybrid (SWA) models: both the full-attention and the SWA pools must report
          zero used tokens.
        - Other models: free + evictable tokens must equal ``max_total_num_tokens``,
          minus the tree cache's protected size when hierarchical caching is on.

        Request-slot check: every ``req_to_token_pool`` slot must be free. On
        disaggregated decode servers the expected total also includes
        ``pre_alloc_size``.

        When metrics are enabled and more than 30 seconds have passed since the last
        log, idle-time stats are also logged. ``_publish_kv_events`` is always
        called at the end.

        Raises:
            ValueError: If a token or request-slot leak is detected.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # With hierarchical caching, protected tokens are excluded from the
            # expected free + evictable total.
            memory_leak = (available_size + evictable_size) != (
                self.max_total_num_tokens
                if not self.enable_hierarchical_cache
                else self.max_total_num_tokens - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        num_free_req_slots = len(self.req_to_token_pool.free_slots)
        if num_free_req_slots != req_total_size:
            # Report the same total that was compared against. On decode servers it
            # includes pre-allocated slots, so ``req_to_token_pool.size`` alone
            # would be misleading.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_req_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1461

Hanming Lu's avatar
Hanming Lu committed
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
    def check_tree_cache(self):
        """Run the SWA radix cache's ``sanity_check`` on hybrid (SWA) models.

        Does nothing for non-hybrid models or when the tree cache is not an
        ``SWARadixCache``.
        """
        if not self.is_hybrid:
            return
        if isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

1497
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch the scheduler should run next.

        1. If a chunked request is in flight, cache its unfinished prefix and free
           its ``req_to_token_pool`` slot so it can be excluded from merging.
        2. If the previous batch was an extend batch, filter out the excluded
           chunked request(s) and merge what remains into ``running_batch``.
           Prefill-only batches are not merged.
        3. Prefer a new prefill batch; otherwise advance the running decode batch.
        4. When MLP sync is required (DP attention), optionally gather
           DP-balance data and run ``prepare_mlp_sync_batch`` on the choice.

        Returns:
            The batch to run, or ``None`` when nothing is runnable (when MLP sync is
            required, whatever ``prepare_mlp_sync_batch`` returns).
        """
        excluded = set()
        chunked = self.chunked_req
        if chunked:
            # Pull the chunked request out so only finished requests get merged into
            # running_batch. It keeps its rid but will receive a new req_pool_idx.
            excluded.add(chunked)
            self.tree_cache.cache_unfinished_req(chunked)
            self.req_to_token_pool.free(chunked.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            stale_chunked = prev.chunked_req
            if stale_chunked is not None:
                # With pipeline parallelism, after the last chunk this microbatch may
                # still track an outdated chunked request; discard it too.
                excluded.add(stale_chunked)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Prefill-only batches never need a decoding step, so they are not merged.
            if not (prev.is_empty() or prev.is_prefill_only):
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        needs_mlp_sync = require_mlp_sync(self.server_args)
        if needs_mlp_sync and not self.spec_algorithm.is_none():
            # Speculative decoding cannot mix prefill and decode batches in one DP
            # attention group, so idle batches are prepared ahead of time; decode
            # preparation is skipped when the group already has a prefill batch.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            needs_mlp_sync = new_batch is None

        if new_batch is not None:
            # Prefill takes priority over decode.
            chosen = new_batch
        elif self.running_batch.is_empty():
            chosen = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = None if self.running_batch.is_empty() else self.running_batch

        if needs_mlp_sync:
            if (
                self.server_args.load_balance_method == "minimum_tokens"
                and self.forward_ct % 40 == 0
            ):
                self.handle_dp_balance_data(chosen)
            chosen = self.prepare_mlp_sync_batch(chosen)

        return chosen
1561

1562
1563
1564
1565
1566
1567
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted alongside ``running_bs`` running ones.

        The limit is the ``max_micro_batch_size`` server arg. With pipeline
        parallelism (``pp_size > 1``) it is also capped by the number of free
        ``req_to_token_pool`` slots. The result can be zero or negative.
        """
        headroom = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size > 1:
            return min(headroom, self.req_to_token_pool.available_size())
        return headroom

Lianmin Zheng's avatar
Lianmin Zheng committed
1568
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch, or return ``None`` if nothing can run.

        The in-flight chunked request (if any) is continued first. Waiting requests
        are then admitted in priority order until one of these stops admission:
        - the ``PrefillAdder``,
        - the allocatable-request limit,
        - LoRA adapter capacity,
        - on disaggregated prefill servers, free ``req_to_token_pool`` slots.

        Requests still prefetching from hicache storage are skipped. Admitted
        requests are removed from ``waiting_queue``. With mixed chunked prefill, the
        running decode requests are folded into the returned batch and
        ``running_batch`` is replaced by an empty batch.

        Side effects: may set ``self.running_batch.batch_is_full`` and update
        ``self.chunked_req``.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {r.lora_id for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set | {r.lora_id for r in adder.can_run_list} | {req.lora_id}
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc queue and transfer queue can also take memory,
                # so check against the actual available size of req_to_token_pool.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the lookup set once so filtering stays linear in the queue length.
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1720
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next decoding step.

        Finished requests are filtered out first.
        - If the batch lacks KV memory for another decode step, or the
          ``TEST_RETRACT`` flag is set and the batch has more than 10 requests,
          some requests are retracted back to the queue and ``new_token_ratio``
          is reset to the value ``retract_decode`` returns.
        - Otherwise ``new_token_ratio`` decays by ``new_token_ratio_decay``, down
          to ``min_new_token_ratio``.

        ``batch_is_full`` is cleared whenever the batch shrank.

        Returns:
            The same ``batch`` object. It is returned without
            ``prepare_for_decode`` if it became empty.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        out_of_memory = not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if out_of_memory or (TEST_RETRACT and batch.batch_size() > 10):
            ratio_before = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            retracted_count = len(retracted)

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {retracted_count}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )

            self._extend_requests_to_queue(retracted, is_retracted=True)
            self.total_retracted_reqs += retracted_count
        else:
            # No memory pressure: decay the ratio toward its floor.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1759

1760
1761
1762
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and package the outputs.

        Before the pass, this method:
        - increments ``forward_ct``,
        - lets ``_profile_batch_predicate`` inspect the batch,
        - sleeps ``forward_sleep_time`` seconds when that is set.

        Which worker runs the pass:
        - Generation models use the TP worker, or the draft worker when speculative
          decoding is enabled.
        - Embedding and reward models use ``forward_batch_embedding``.

        Returns:
            For generation models, a ``GenerationBatchResult``. On the last pipeline
            rank it holds logits and next token ids; on other ranks it holds the
            hidden-state proxy tensors. Otherwise, an ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Let the profiler decide whether this batch should be profiled.
        self._profile_batch_predicate(batch)

        delay = self.forward_sleep_time
        if delay is not None:
            logger.info(f"Scheduler.run_batch sleep {delay}s")
            time.sleep(delay)

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        is_last = self.pp_group.is_last_rank

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            # Point the hicache consumer index at the batch about to run.
            self.tp_worker.set_hicache_consumer(worker_batch.hicache_consumer_index)
            forward_out = self.tp_worker.forward_batch_generation(worker_batch)
            if is_last:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = worker_batch.bid
        else:
            # NOTE(review): this path never binds pp_hidden_states_proxy_tensors, so it
            # appears to assume it only runs on the last pipeline rank - confirm.
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        if is_last:
            batch.output_ids = next_token_ids

        # The overlap schedule can modify these per-request values later, so snapshot
        # them now for use during output processing.
        extend_input_len_per_req = (
            [r.extend_input_len for r in batch.reqs]
            if batch.return_logprob or self.spec_algorithm.is_eagle()
            else None
        )
        extend_logprob_start_len_per_req = (
            [r.extend_logprob_start_len for r in batch.reqs]
            if batch.return_logprob
            else None
        )

        if is_last:
            logits_field, proxy_field, token_ids_field = (
                logits_output,
                None,
                next_token_ids,
            )
        else:
            logits_field, proxy_field, token_ids_field = (
                None,
                pp_hidden_states_proxy_tensors,
                None,
            )

        return GenerationBatchResult(
            logits_output=logits_field,
            pp_hidden_states_proxy_tensors=proxy_field,
            next_token_ids=token_ids_field,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1840

1841
1842
1843
1844
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch a finished batch's result to the handler for its forward mode.

        - Decode and extend batches go to their dedicated processors.
        - Idle batches do work only under the overlap scheduler: the last worker
          result is resolved and sampling info is marked done.
        - Dummy-first batches only mark sampling info done.

        A pending health-check reply is sent afterwards in every case.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        """Send one owed health-check reply via ``send_to_tokenizer``, if any.

        ``return_health_check_ct`` counts outstanding replies; each call consumes at
        most one. Replying from the batch loop keeps health checks from being
        blocked by long-context prefill. One caveat: this path does not check the
        status of the detokenizer manager.
        """
        if not self.return_health_check_ct:
            return
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1868
1869
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Forward ``local_batch`` to ``prepare_mlp_sync_batch_raw``.

        The scheduler's DP/TP topology, CUDA-graph, speculative-decoding and
        overlap-schedule settings are supplied from ``server_args`` and instance
        state. Returns whatever ``prepare_mlp_sync_batch_raw`` returns.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
        )

1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
        """Share this worker's DP-balance bookkeeping with the other TP ranks.

        Every rank contributes its current load (``self.get_load()``) and the
        dp-balance ids it received since the previous call. Rank 0 gathers them
        over ``self.tp_cpu_group``. Holding ``self.balance_meta.mutex``, it then
        removes the received ids from the shared "on-fly" tables and writes the
        tables plus the per-worker loads back through ``self.balance_meta``.
        The per-term id buffer is cleared on every rank afterwards.

        ``local_batch`` is not used by this method.
        """
        # The maximum size of the tensor used for gathering data from all workers.
        gather_tensor_size = 512
        # Slots 0 and 1 hold the header (holding tokens, number of ids), so at
        # most ``gather_tensor_size - header_size`` ids fit into one gather.
        header_size = 2
        max_ids_per_term = gather_tensor_size - header_size

        def gather_dp_balance_info(
            holding_tokens,
        ) -> Tuple[Optional[List[List[int]]], Union[int, List[int]]]:
            """Gather received dp-balance ids and holding tokens from all workers.

            Returns ``(ids_per_worker, tokens_per_worker)`` on rank 0. Other
            ranks get ``(None, holding_tokens)`` back unchanged.
            """
            recv_list = self.recv_dp_balance_id_this_term
            # Previously this allowed 511 ids, which overflows the 510 id slots
            # left after the 2-slot header.
            assert len(recv_list) <= max_ids_per_term, (
                "The number of requests received this round is too large. "
                "Please increase gather_tensor_size and onfly_info_size."
            )

            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
            recv_tensor[0] = holding_tokens
            recv_tensor[1] = len(recv_list)
            recv_tensor[header_size : header_size + len(recv_list)] = torch.tensor(
                recv_list, dtype=torch.int32
            )

            if self.tp_rank == 0:
                gathered_list = [
                    torch.zeros(gather_tensor_size, dtype=torch.int32)
                    for _ in range(self.balance_meta.num_workers)
                ]
            else:
                gathered_list = None

            torch.distributed.gather(
                recv_tensor, gathered_list, group=self.tp_cpu_group
            )

            if self.tp_rank != 0:
                return None, holding_tokens

            ids_per_worker = []
            tokens_per_worker = []
            for tensor in gathered_list:
                tokens_per_worker.append(tensor[0].item())
                num_ids = tensor[1].item()
                ids_per_worker.append(
                    tensor[header_size : header_size + num_ids].tolist()
                )
            return ids_per_worker, tokens_per_worker

        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
            """Drop newly received ids from the shared on-fly tables and publish loads."""
            meta = self.balance_meta

            with meta.mutex:
                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
                assert len(new_recv_rid_lists) == len(
                    onfly_list
                ), "num_worker not equal"
                # 1. Every rid a worker received this round must still be on-fly
                #    for that worker; remove it from that worker's on-fly table.
                for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(
                    zip(new_recv_rid_lists, onfly_list)
                ):
                    for new_recv_rid in new_recv_rids:
                        assert (
                            new_recv_rid in on_fly_reqs
                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                        del on_fly_reqs[new_recv_rid]
                # 2. Atomically write local_tokens and onfly into shm under the mutex
                meta.set_shared_onfly_info(onfly_list)
                meta.set_shared_local_tokens(local_tokens)

        holding_tokens = self.get_load()

        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
            holding_tokens
        )

        self.recv_dp_balance_id_this_term.clear()
        if self.tp_rank == 0:  # only first worker write info
            write_shared_dp_balance_info(
                new_recv_dp_balance_id_list, holding_token_list
            )

1962
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        """All-gather per-rank batch metadata and annotate ``local_batch`` with it.

        Each rank contributes an int64 vector with these fields:
        ``num_tokens``, ``can_cuda_graph``, ``num_tokens_for_logprob``,
        ``is_extend_in_batch``, followed by the fields returned by
        ``TboDPAttentionPreparer.prepare_all_gather``. The vectors are gathered
        into a ``(dp_size, attn_tp_size, num_fields)`` tensor. The gather uses
        ``tp_group.device_group`` when overlap scheduling is disabled, and
        ``tp_group.cpu_group`` otherwise. Only column ``[:, 0]`` (attention-TP
        index 0 of each DP rank) is read for the first four fields.

        If ``local_batch`` is None but some DP rank has tokens, ``get_idle_batch()``
        is called to create one.

        ``spec_algorithm`` and ``speculative_num_draft_tokens`` are accepted but
        not used here.

        Returns:
            The (possibly newly created idle) local batch, or None.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        if disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            group = tp_group.cpu_group
            device = "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        # Size the gather buffer from local_info itself instead of a hard-coded
        # 6, so it always matches however many fields tbo_preparer contributes.
        # NOTE(review): assumes the group's world size is dp_size * attn_tp_size.
        num_fields = local_info.numel()
        global_info = torch.empty(
            (dp_size, attn_tp_size, num_fields),
            dtype=torch.int64,
            device=device,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=group,
        )
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        # Everything after the four fixed fields belongs to tbo_preparer.
        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
2065
2066
2067
2068
2069

    def get_idle_batch(self):
        """Create an empty ScheduleBatch bound to this scheduler's pools and mark it idle.

        Returns:
            A ScheduleBatch with no requests on which ``prepare_for_idle()`` has
            been called.
        """
        batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        batch.prepare_for_idle()
        return batch

2079
2080
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Pending grammar futures are polled in queue order, waiting up to 0.03 s
        each. Polling stops at the first future that is not ready; that request
        is counted as timed out once its accumulated wait count exceeds
        ``GRAMMAR_TIMEOUT / 0.03``.

        With more than one (attention-)TP rank, the ranks all-reduce (MAX) the
        ready and timed-out counts. Ranks that lag behind block on the extra
        futures, so every rank moves the same queue prefix and aborts the same
        timed-out request.
        """
        # Per-request wait used when polling a grammar future.
        poll_timeout = 0.03

        def install_grammar(req, grammar):
            # Store the compiled grammar, cache a copy under its key, and abort
            # the request if compilation produced the invalid-grammar sentinel.
            req.grammar = grammar
            self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
            if req.grammar is INVALID_GRAMMAR_OBJ:
                req.set_finish_with_abort(
                    f"Invalid grammar request: {req.grammar_key=}"
                )

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                install_grammar(req, req.grammar.result(timeout=poll_timeout))
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Block on the futures that some other rank already saw as ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                install_grammar(req, req.grammar.result())
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests come right after the globally agreed ready prefix.
        # Starting at the local `num_ready_reqs` would, on a lagging rank, cancel
        # a request that was just resolved above. It would also leave the real
        # timed-out request un-aborted, desynchronising the ranks. Clamp to the
        # queue length in case the agreed ready prefix already covers the queue.
        timeout_end = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for req in self.grammar_queue[num_ready_reqs_max:timeout_end]:
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = timeout_end

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

2144
2145
2146
2147
2148
2149
2150
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Finalize the batch's next sampling info and signal that it is done.

        When grammars are attached, the regex vocab mask is updated and
        ``self.current_stream`` is synchronized before the
        ``sampling_info_done`` event is set. Does nothing if the batch has no
        next sampling info.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

2151
2152
2153
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Roughly every ``watchdog_timeout / 2`` seconds it checks whether
        ``self.forward_ct`` advanced while ``self.cur_batch`` is set. If it has
        not advanced for more than ``watchdog_timeout`` seconds, the thread:

        - logs batch and memory-pool info (unless request logging is disabled);
        - calls ``pyspy_dump_schedulers()``;
        - waits 5 s and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: `// 2` floored timeouts below 2 s (and fractional
            # values) to 0, turning this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2204
2205
2206
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a FlushCacheReqInput by running flush_cache() and reporting its result."""
        return FlushCacheReqOutput(success=self.flush_cache())
2207

2208
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when nothing is queued or running, including
        every pipeline micro-batch when ``pp_size > 1``. On success it:

        - clears the tree cache and grammar cache;
        - clears the request/token pools, including the draft worker's when a
          speculative algorithm is configured;
        - resets the generation and speculative-acceptance counters;
        - calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if requests were pending.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Use the module logger (like the rest of this class) rather than
            # the root `logging` module.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
2245
2246
    def get_load(self):
        """Estimate this scheduler's load as a token count.

        The load is the KV-cache tokens in use that cannot be freed:
        capacity minus available minus evictable. Hybrid (full + SWA) pools
        take the larger of the two. On top of that, it adds the prompt lengths
        of requests still waiting in the waiting queue and, under PD
        disaggregation, in the prefill bootstrap or decode prealloc queue.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache
        if not self.is_hybrid:
            load = (
                self.max_total_num_tokens
                - allocator.available_size()
                - cache.evictable_size()
            )
        else:
            load = max(
                self.full_tokens_per_layer
                - allocator.full_available_size()
                - cache.full_evictable_size(),
                self.swa_tokens_per_layer
                - allocator.swa_available_size()
                - cache.swa_evictable_size(),
            )

        load += sum(len(r.origin_input_ids) for r in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            pending = self.disagg_prefill_bootstrap_queue.queue
            load += sum(len(r.origin_input_ids) for r in pending)
        elif mode == DisaggregationMode.DECODE:
            pending = self.disagg_decode_prealloc_queue.queue
            load += sum(len(d.req.origin_input_ids) for d in pending)

        return load

2279
2280
2281
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of scheduler state for a GetInternalStateReq.

        The snapshot is a copy of ``global_server_args_dict`` extended with:

        - ``last_gen_throughput``;
        - ``memory_usage`` (weight, kvcache, token capacity, plus cuda_graph
          when not on CPU);
        - ``avg_spec_accept_length``, when speculative stats exist;
        - ``step_time_dict``, when ``RECORD_STEP_TIME`` is set;
        - the current ``load``.
        """
        model_runner = self.tp_worker.worker.model_runner

        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        memory_usage = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }
        if not _is_cpu:
            memory_usage["cuda_graph"] = round(model_runner.cuda_graph_mem_usage, 2)
        state["memory_usage"] = memory_usage

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
2307
2308
2309
2310
2311

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of global server args at runtime.

        Only ``max_micro_batch_size`` (which must lie in
        ``[1, max_running_requests // pp_size]``) and the two speculative accept
        thresholds may be changed. Validation stops at the first rejected key,
        and in that case nothing is applied.

        On success, the current average speculative accept length is logged
        (when available), the acceptance counters are reset, and every
        requested key is written into ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput. Its ``updated`` flag now reports whether
            the update was actually applied; it used to be always True.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2346
2347
2348
2349
2350
2351
2352
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        Any exception is logged and turned into a failed RpcReqOutput carrying
        its message. ``torch.distributed.barrier()`` is called before replying,
        whether or not the call succeeded.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        succeeded = error is None
        return RpcReqOutput(succeeded, "" if succeeded else str(error))

2365
2366
    def abort_request(self, recv_req: AbortReq):
        """Abort every request matching ``recv_req`` wherever it currently lives.

        A request matches when ``recv_req.abort_all`` is set or its rid starts
        with ``recv_req.rid``. Depending on the request's location it is:

        - popped from the waiting queue, with an AbortReq sent back to the
          tokenizer;
        - marked finished-with-abort in the grammar queue, with its pending
          grammar cancelled;
        - under PD disaggregation, handled by calling ``abort()`` on its KV
          sender/receiver when one is available;
        - in a running batch, flagged with ``to_abort``.
        """

        def matches(rid) -> bool:
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [i for i, req in enumerate(self.waiting_queue) if matches(req.rid)]

        # Pop from the back so the remaining indices stay valid while deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if matches(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled.
        # Only matching requests are logged; previously every queued request was.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if matches(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if matches(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and matches(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2438

2439
2440
2441
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause hook; this scheduler does not implement it and always raises."""
        raise NotImplementedError()

2442
2443
2444
2445
2446
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        The request is delegated to the TP worker, and its result is returned as-is.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter by delegating to the TP worker and returning its result."""
        return self.tp_worker.unload_lora_adapter(recv_req)

2458
2459
2460
2461
2462
2463
2464
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` in ``self.forward_sleep_time``.

        Non-positive values are normalised to None, presumably meaning
        "no artificial slowdown" (confirm against the consumer).
        """
        sleep_time = recv_req.forward_sleep_time
        disabled = sleep_time is None or sleep_time <= 0
        self.forward_sleep_time = None if disabled else sleep_time
        return SlowDownReqOutput()

2465
2466
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump the global expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is not one of the known commands.
        """
        handlers = (
            (ExpertDistributionReq.START_RECORD, lambda rec: rec.start_record()),
            (ExpertDistributionReq.STOP_RECORD, lambda rec: rec.stop_record()),
            (ExpertDistributionReq.DUMP_RECORD, lambda rec: rec.dump_record()),
        )
        for command, handler in handlers:
            if recv_req == command:
                handler(get_global_expert_distribution_recorder())
                return ExpertDistributionReqOutput()
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
2475

2476
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        Returns:
            OpenSessionReqOutput(session_id, ok). ``ok`` is False (with a
            warning logged) when the id already exists or is None.
        """
        session_id = recv_req.session_id

        # handle error
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2490
2491
2492
2493
2494
2495
2496
2497
2498

    def close_session(self, recv_req: CloseSessionReqInput):
        """Forget the session named by ``recv_req.session_id``; warn if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        # handle error
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2499
2500
    def get_print_prefix(self):
        """Return a rank tag such as ``" DP0 TP1 PP0"`` for log lines.

        Each component is included only when relevant: DP when
        ``attn_dp_rank`` is set, TP when ``tp_size > 1``, PP when ``pp_size > 1``.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2509
2510
    def current_scheduler_metrics_enabled(self):
        """Whether this scheduler should emit metrics.

        True on attention-TP rank 0; otherwise the value of
        ``enable_metrics_for_all_schedulers``.
        """
        if self.attn_tp_rank == 0:
            return True
        return self.enable_metrics_for_all_schedulers
2511

2512
2513
2514
    def maybe_sleep_on_idle(self):
        """Delegate to ``self.idle_sleeper.maybe_sleep()`` when an idle sleeper is configured."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2515

2516
2517
2518
2519
2520
2521
    def handle_freeze_gc(self, recv_req: FreezeGCReq):
        """Freeze this scheduler's GC, then forward the same request to the detokenizer.

        Returns None; no reply is produced here.
        """
        freeze_gc("Scheduler")
        self.send_to_detokenizer.send_pyobj(recv_req)

2522

2523
2524
2525
2526
2527
2528
2529
class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This would lead not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important in cases when multiple GPUs are connected
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    def __init__(self, sockets):
        """Register every socket in ``sockets`` for POLLIN on a new zmq.Poller."""
        self.poller = zmq.Poller()
        # Monotonic clock: the empty-cache interval must not be skewed by
        # wall-clock adjustments, which time.time() follows.
        self.last_empty_time = time.monotonic()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a registered socket is readable or 1000 ms elapse.

        Afterwards, if ``global_config.torch_empty_cache_interval`` is positive
        and more than that many seconds have passed since the last release,
        call ``torch.cuda.empty_cache()``.
        """
        self.poller.poll(1000)  # zmq poll timeout is in milliseconds
        interval = global_config.torch_empty_cache_interval
        if interval > 0:
            now = time.monotonic()
            if now - self.last_empty_time > interval:
                self.last_empty_time = now
                torch.cuda.empty_cache()
2550

2551

2552
2553
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is an internal health-check generate request.

    Health-check requests are identified by an ``rid`` starting with
    ``"HEALTH_CHECK"``. A request whose ``rid`` is missing or not a string
    (e.g. ``None`` or a list of ids) is never a health check; previously such
    values raised AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
2554

2555
2556

def is_work_request(recv_req):
    """Return True for tokenized generate/embedding requests, single or batched."""
    work_request_types = (
        TokenizedGenerateReqInput,
        TokenizedEmbeddingReqInput,
        BatchTokenizedGenerateReqInput,
        BatchTokenizedEmbeddingReqInput,
    )
    return isinstance(recv_req, work_request_types)
2566
2567


2568
2569
2570
2571
2572
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    """Entry point of a scheduler subprocess.

    Configures process-level state (title, fault handler, parent watchdog,
    logging, optional CPU affinity) and builds a ``Scheduler``. It then reports
    readiness through ``pipe_writer`` and runs the event loop selected by the
    disaggregation mode, ``server_args.pp_size`` and ``scheduler.enable_overlap``.

    Args:
        server_args: Server-wide configuration (rank sizes, nnodes, ...).
        port_args: Forwarded unchanged to ``Scheduler``.
        gpu_id: GPU index; forwarded to ``Scheduler`` and used for CPU affinity.
        tp_rank: TP rank; shown in the prefix when ``tp_size > 1``.
        moe_ep_rank: EP rank; shown in the prefix when ``ep_size > 1``.
        pp_rank: PP rank; shown in the prefix when ``pp_size > 1``.
        dp_rank: DP rank or ``None``. When ``None``, it is taken from the
            ``SGLANG_DP_RANK`` environment variable if that is set.
        pipe_writer: Only ``.send`` is used here (presumably a multiprocessing
            Connection); it receives one dict with ``"status": "ready"`` once
            the scheduler is constructed.
        balance_meta: Forwarded to ``Scheduler`` as ``dp_balance_meta``.

    Any exception raised while constructing or running the scheduler is logged
    with its traceback, and ``SIGQUIT`` is sent to the parent process.
    """
    # Rank tag shared by the process title and the log prefix, e.g. " DP0 TP1".
    rank_tags = []
    if dp_rank is not None:
        rank_tags.append(f" DP{dp_rank}")
    if server_args.tp_size > 1:
        rank_tags.append(f" TP{tp_rank}")
    if server_args.ep_size > 1:
        rank_tags.append(f" EP{moe_ep_rank}")
    if server_args.pp_size > 1:
        rank_tags.append(f" PP{pp_rank}")
    prefix = "".join(rank_tags)

    # Process-level setup.
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # [For Router] fall back to SGLANG_DP_RANK when no DP rank was passed in.
    # NOTE(review): this runs after ``prefix`` was built, so an env-provided
    # rank never appears in the process title or the log prefix. Confirm
    # whether that is intended.
    env_dp_rank = os.environ.get("SGLANG_DP_RANK")
    if dp_rank is None and env_dp_rank is not None:
        dp_rank = int(env_dp_rank)

    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Opt-in CPU affinity for this GPU process.
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
            dp_balance_meta=balance_meta,
        )
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
            "max_req_input_len": scheduler.max_req_input_len,
        }
        pipe_writer.send(ready_message)

        # Select and enter the blocking event loop.
        mode: DisaggregationMode = scheduler.disaggregation_mode
        if mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif mode == DisaggregationMode.PREFILL:
            # Overlap takes precedence over pipeline parallelism here.
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            elif server_args.pp_size > 1:
                scheduler.event_loop_pp_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {tb_text}")
        parent_process.send_signal(signal.SIGQUIT)