scheduler.py 105 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
53
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
54
    ReqToMetadataIdxAllocator,
55
    TransferBackend,
56
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
57
)
58
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
59
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
from sglang.srt.layers.moe import initialize_moe_config
68
69
from sglang.srt.managers.io_struct import (
    AbortReq,
70
    CloseSessionReqInput,
71
    ExpertDistributionReq,
72
    ExpertDistributionReqOutput,
73
74
    FlushCacheReqInput,
    FlushCacheReqOutput,
75
76
    GetInternalStateReq,
    GetInternalStateReqOutput,
77
    GetWeightsByNameReqInput,
78
    HealthCheckOutput,
79
    InitWeightsUpdateGroupReqInput,
80
81
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
82
83
    OpenSessionReqInput,
    OpenSessionReqOutput,
84
    ProfileReq,
85
86
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
87
88
    RpcReqInput,
    RpcReqOutput,
89
90
    SetInternalStateReq,
    SetInternalStateReqOutput,
91
92
    SlowDownReqInput,
    SlowDownReqOutput,
93
94
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
95
96
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
97
    UpdateWeightFromDiskReqInput,
98
    UpdateWeightsFromDistributedReqInput,
99
    UpdateWeightsFromTensorReqInput,
100
)
101
from sglang.srt.managers.mm_utils import init_embedding_cache
102
103
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
104
    MultimodalInputs,
105
106
    Req,
    ScheduleBatch,
107
    global_server_args_dict,
108
)
109
110
111
112
113
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
114
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
115
116
117
118
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
119
120
121
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
122
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
123
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
124
125
126
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
127
from sglang.srt.managers.session_controller import Session
128
from sglang.srt.managers.tp_worker import TpModelWorker
129
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
130
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
tarinkk's avatar
tarinkk committed
131
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
132
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
133
from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
134
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
135
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Lianmin Zheng's avatar
Lianmin Zheng committed
136
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
137
from sglang.srt.reasoning_parser import ReasoningParser
138
from sglang.srt.server_args import PortArgs, ServerArgs
139
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
140
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
141
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
142
from sglang.srt.utils import (
143
    DynamicGradMode,
144
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
145
    configure_gc_logger,
146
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
147
    disable_request_logging,
148
    get_available_gpu_memory,
149
    get_bool_env_var,
150
    get_zmq_socket,
151
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
152
    kill_itself_when_parent_died,
153
    point_to_point_pyobj,
154
    pyspy_dump_schedulers,
155
156
    require_mlp_sync,
    require_mlp_tp_gather,
157
    set_gpu_proc_affinity,
158
159
160
    set_random_seed,
    suppress_other_loggers,
)
161
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
162
163
164

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes (enabled via SGLANG_TEST_RETRACT).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Grammar timeout read from SGLANG_GRAMMAR_TIMEOUT (default 300); presumably
# seconds -- confirm against the code that consumes it.
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

# Evaluated once at import time so hot paths can branch on a plain bool.
_is_cpu = is_cpu()


@dataclass
class GenerationBatchResult:
    """Per-batch result container for generation (non-embedding) forward passes.

    All fields are plain data; the producer and consumers live outside this
    view.
    """

    # Logits processor output, if one was produced for this batch.
    logits_output: Optional[LogitsProcessorOutput]
    # Proxy hidden-state tensors; name suggests pipeline-parallel hand-off
    # between stages -- NOTE(review): confirm against the producer.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids, if any.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start offsets within the extended input.
    extend_logprob_start_len_per_req: List[int]
    # Batch id.
    bid: int
    # Whether this batch was eligible to run with a CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Per-batch result container for embedding forward passes."""

    # Embeddings produced for the batch; shape is set by the producer
    # (not visible here) -- NOTE(review): presumably one row per request.
    embeddings: torch.Tensor
    # Batch id.
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
191
192
193
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
194
195
196
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
197
198
199
200
201
202
203
204
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        dp_balance_meta: Optional[DPBalanceMeta] = None,
    ):
        """Initialize the scheduler for one (tp_rank, pp_rank, dp_rank) worker.

        Sets up, in order: config attributes parsed from ``server_args``,
        ZMQ sockets (only on pp_rank 0 / attention-TP rank 0), the tokenizer,
        the TP model worker and an optional EAGLE draft worker, the memory
        pool and prefix cache, the grammar backend, the schedule policy and
        new-token-ratio estimation, the watchdog thread, profiler/metrics,
        disaggregation state, and the request dispatcher.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names (scheduler input, RPC, tokenizer,
                detokenizer, metrics) and the NCCL port.
            gpu_id: GPU index handed to the model workers.
            tp_rank: Tensor-parallel rank of this scheduler.
            moe_ep_rank: MoE expert-parallel rank.
            pp_rank: Pipeline-parallel rank.
            dp_rank: Data-parallel rank, or None.
            dp_balance_meta: Shared DP load-balancing metadata. Required when
                DP attention is enabled with the "minimum_tokens" load
                balance method.

        Raises:
            AssertionError: if ``schedule_conservativeness`` is negative, or
                ``dp_balance_meta`` is missing when it is required.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            # Only the first attention-TP rank of the first PP stage talks to
            # the tokenizer/detokenizer; other ranks get no receive sockets and
            # no-op senders so call sites need no rank checks.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer (also sets self.is_generation, used below)
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget evenly across PP stages.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache (sets self.tree_cache, used by the policy)
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill.
        # A non-positive value (-1) means "disabled". ServerArgs may also leave
        # it as None (init_memory_pool_and_cache checks for None), so treat None
        # as disabled instead of raising TypeError on `None <= 0`.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # The new-token ratio starts at init_new_token_ratio and decays linearly
        # towards min_new_token_ratio over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread.
        # Resolve the parent process first so the attribute already exists if
        # the (daemon) watchdog thread reads it.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profier()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

        self.balance_meta = dp_balance_meta
        if (
            server_args.enable_dp_attention
            and server_args.load_balance_method == "minimum_tokens"
        ):
            assert dp_balance_meta is not None

        self.recv_dp_balance_id_this_term = []

    def init_tokenizer(self):
        """Set ``self.is_generation`` and load ``self.tokenizer`` / ``self.processor``.

        - ``skip_tokenizer_init``: both tokenizer and processor are None.
        - Multimodal models: load a processor and derive the tokenizer from it.
        - Text-only models: load a tokenizer; ``processor`` is None so the
          attribute is always defined, matching the other two branches.
        """
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        if self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Previously left unset here; define it so readers of
            # self.processor never hit AttributeError for text-only models.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        ``self.tree_cache`` is chosen in this order:
        1. ``SWAChunkCache`` / ``ChunkCache`` when chunked prefill is configured
           (``chunked_prefill_size is not None``) and the radix cache is disabled.
        2. ``RadixCacheCpp`` when ``SGLANG_EXPERIMENTAL_CPP_RADIX_TREE=1``.
        3. ``HiRadixCache`` when hierarchical cache is enabled.
        4. ``SWARadixCache`` for hybrid (sliding-window) models.
        5. ``LoRARadixCache`` when LoRA is enabled.
        6. ``RadixCache`` otherwise (honoring ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier`` (1 without speculative
        decoding) and initializes the embedding cache with a size of
        ``SGLANG_VLM_CACHE_SIZE_MB`` MiB (default 100).

        Raises:
            AssertionError: on unsupported combinations (hybrid + disaggregation,
                LoRA with a non-FCFS schedule policy).
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            if self.is_hybrid:
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        else:
            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
                # lazy import to avoid JIT overhead
                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

                self.tree_cache = RadixCacheCpp(
                    disable=False,
                    use_hicache=self.enable_hierarchical_cache,
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool_allocator,
                    tp_cache_group=self.tp_cpu_group,
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )
            elif self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    # With DP attention the cache is synchronized only within
                    # the attention-TP group rather than the full TP group.
                    tp_cache_group=(
                        self.attn_tp_cpu_group
                        if self.server_args.enable_dp_attention
                        else self.tp_cpu_group
                    ),
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    hicache_io_backend=server_args.hicache_io_backend,
                    hicache_mem_layout=server_args.hicache_mem_layout,
                    hicache_storage_backend=server_args.hicache_storage_backend,
                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                )
                self.tp_worker.register_hicache_layer_transfer_counter(
                    self.tree_cache.cache_controller.layer_done_counter
                )
            elif self.is_hybrid:
                assert (
                    self.server_args.disaggregation_mode == "null"
                ), "Hybrid mode does not support disaggregation yet"
                self.tree_cache = SWARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    sliding_window_size=self.sliding_window_size,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                )
            elif self.enable_lora:
                # NOTE: this first assert can never fire -- the hierarchical
                # cache branch above is taken whenever it is enabled. Kept as a
                # guard in case the branch order changes.
                assert (
                    not self.enable_hierarchical_cache
                ), "LoRA radix cache doesn't support hierarchical cache"
                assert (
                    self.schedule_policy == "fcfs"
                ), "LoRA radix cache only supports FCFS policy"
                self.tree_cache = LoRARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )

        # With speculative decoding each decode step may write
        # draft_tokens + topk * steps slots per request instead of one.
        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

    def init_disaggregation(self):
        """Initialise state for prefill/decode (PD) disaggregation.

        ``self.transfer_backend`` is always built from
        ``server_args.disaggregation_transfer_backend``. The rest depends on
        ``self.disaggregation_mode``:

        * ``DECODE``: a metadata-slot allocator and ``MetadataBuffers`` with
          ``2 * req_to_token_pool.size`` slots, a ``DecodeTransferQueue``
          (decode requests polling for their KV cache) and a
          ``DecodePreallocQueue`` (decode requests waiting for pre-allocation)
          that is handed that transfer queue.
        * ``PREFILL``: a metadata-slot allocator and ``MetadataBuffers`` with
          ``2 * max_running_requests`` slots, a ``PrefillBootstrapQueue`` and an
          empty ``disagg_prefill_inflight_queue``.
        * any other mode: nothing further.
        """
        args = self.server_args
        self.transfer_backend = TransferBackend(args.disaggregation_transfer_backend)

        def setup_metadata(num_slots):
            # Both roles share this allocator/buffer layout; only the slot
            # count differs.
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                num_slots
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                num_slots,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

        def draft_kv_pool():
            # KV pool of the draft worker's model runner, or None without one.
            if self.draft_worker is None:
                return None
            return self.draft_worker.model_runner.token_to_kv_pool

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # Twice the request pool size, for headroom.
            setup_metadata(self.req_to_token_pool.size * 2)

            # Decode requests polling for their KV cache.
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # Decode requests waiting for KV pre-allocation.
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Twice the max running requests, for headroom.
            setup_metadata(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=args.disaggregation_decode_tp,
                decode_dp_size=args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # Prefill requests whose KV cache is still being sent.
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
767

768
769
770
771
    def init_moe_config(self):
        """Call ``initialize_moe_config(self.server_args)`` only when the HF
        config exposes ``num_experts_per_tok`` (presumably an MoE model);
        otherwise do nothing."""
        if not hasattr(self.model_config.hf_config, "num_experts_per_tok"):
            return
        initialize_moe_config(self.server_args)

772
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the non-overlapped scheduling loop forever.

        Every iteration ingests freshly received requests and picks the next
        batch. A batch is run and its result processed synchronously. With no
        batch the scheduler is idle, so it runs its idle self-check instead.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # Idle: self-check and re-init some states.
                self.self_check_during_idle()
            else:
                self.process_batch_result(next_batch, self.run_batch(next_batch))

            self.last_batch = next_batch
790

791
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Results are processed one iteration late. The current batch is launched
        and ``(batch.copy(), result)`` is pushed onto ``self.result_queue``.
        Then, if there was a previous batch, the oldest queued entry is popped
        and processed. When no batch is scheduled and there was no previous
        batch either, the idle self-check runs instead.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Event passed to process_batch_result below (for this batch's
                # dummy predecessor or for the previous batch).
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                # Queue a copy, so later mutation of `batch` does not affect the
                # deferred result processing.
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch (the oldest queued entry)
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: we should use the currently launched batch's launch_done event instead of the last batch's
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop round-robins over ``pp_size`` micro-batch slots. Each slot keeps
        its own running batch (``self.running_mbs``), last batch (``last_mbs``)
        and batch id (``bids``). For slot ``mb_id`` it schedules and runs a
        batch. The last rank then sends that batch's outputs over ``pp_group``.
        Next, the outputs of slot ``(mb_id + 1) % pp_size`` are received and
        processed. Non-last ranks forward the received outputs, the received
        requests and the hidden-state proxy tensors onward.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap in this micro-batch slot's scheduler state.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    # `result` is only (re)bound here; later uses refer to the
                    # most recently run micro-batch.
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            # Flatten logits_output fields as "logits_output.<name>"
                            # keys so they can travel in a tensor dict.
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Rebuild LogitsProcessorOutput from the "logits_output.*" keys, if any.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        # NOTE(review): `result` is the output of the most recent
                        # run_batch (slot mb_id or an earlier slot), not of
                        # mbs[next_mb_id]; this flag may describe a different
                        # micro-batch than the one being processed — confirm intent.
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage; the rank arguments point
                    # from this stage to stage pp_rank + 1 (offset by tp_size)
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When no micro-batch ran in this round, do self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()
966

967
968
    def recv_requests(self) -> List[Req]:
        """Collect newly arrived requests and make them identical on all ranks.

        * If ``self.recv_skipper`` is set and its ``handle`` (given the last
          batch's forward mode) returns falsy, return ``[]`` without polling.
        * Rank (pp_rank 0, attn_tp_rank 0) drains the tokenizer socket and then
          the RPC socket with non-blocking reads.
        * On later pipeline stages, attn_tp_rank 0 receives the list forwarded
          by the previous stage via ``point_to_point_pyobj``.
        * Other ranks start with ``None`` and get the list from a broadcast.
        * ``self.input_blocker`` (if set) may transform the list.
        * With DP attention, work requests are broadcast within the attention-TP
          group and control requests across the full TP group. Otherwise the
          whole list is broadcast across the TP group when ``tp_size != 1``.

        Returns the received objects. Despite the annotation, these include
        control/RPC requests, not only ``Req``-like work requests.
        """

        if self.recv_skipper is not None:
            last_forward_mode = (
                self.last_batch.forward_mode if self.last_batch is not None else None
            )
            if not self.recv_skipper.handle(last_forward_mode):
                return []

        if self.pp_rank == 0:
            if self.attn_tp_rank == 0:
                recv_reqs = []

                # Drain everything currently queued; ZMQError (e.g. nothing to
                # read with NOBLOCK) ends the loop.
                while True:
                    try:
                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_req)

                # NOTE(review): process_input_requests guards `recv_from_rpc is
                # not None`; here it is assumed non-None on this rank — confirm.
                while True:
                    try:
                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_rpc)
            else:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                # Receive from the previous pipeline stage (the rank arguments
                # point from stage pp_rank - 1 to this stage).
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )
            else:
                recv_reqs = None

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            # Split on the source rank: generate/embedding inputs are "work",
            # everything else is "control".
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            # Work requests stay within this attention-TP group.
            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            # Control requests go to every rank of the TP group.
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
1056
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received object and route any reply it produces.

        * Health-check generate requests are counted in
          ``return_health_check_ct`` and skipped while the scheduler has a
          chunked request, a non-empty running batch, or pending offload tags.
        * Work requests are rejected with a 503 ``AbortReq`` when adding one
          more would exceed ``max_queued_requests``.
        * Everything else goes through ``self._request_dispatcher``. An
          ``RpcReqOutput`` reply is sent on ``recv_from_rpc`` (if that socket
          exists); any other non-None reply goes to the tokenizer.
        """
        for req_obj in recv_reqs:
            if is_health_check_generate_req(req_obj):
                busy = (
                    self.chunked_req is not None
                    or not self.running_batch.is_empty()
                    or len(self.offload_tags) > 0
                )
                if busy:
                    # Do not schedule the probe while busy; just count it.
                    self.return_health_check_ct += 1
                    continue

            if is_work_request(req_obj) and (
                len(self.waiting_queue) + 1 > self.max_queued_requests
            ):
                # Waiting queue is full: reject up front.
                rejection = AbortReq(
                    req_obj.rid,
                    finished_reason={
                        "type": "abort",
                        "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                        "message": "The request queue is full.",
                    },
                )
                self.send_to_tokenizer.send_pyobj(rejection)
                continue

            reply = self._request_dispatcher(req_obj)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
1087
1088
1089
1090
1091

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        Invalid requests are not dropped. This covers an unknown session id, a
        multimodal prompt too long after token expansion, an over-long prompt,
        an out-of-range ``logprob_start_len``, and a cached invalid grammar. In
        each case the request is marked with ``set_finish_with_abort`` and
        still enqueued. The exception is a disaggregated-mode request without
        ``bootstrap_room``: it is aborted and streamed back immediately.
        Requests whose grammar is not yet cached are appended to
        ``self.grammar_queue``; all others go through ``_add_request_to_queue``.
        """
        if (
            self.server_args.enable_dp_attention
            and self.server_args.load_balance_method == "minimum_tokens"
        ):
            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_id=recv_req.lora_id,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            # A session id was given but is not a known session.
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap generation so prompt + output fits in max_req_len (1 << 30 acts
        # as "unlimited" when the client did not set max_new_tokens).
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The guard above ensures structural_tag is not None here.
                # (A truthiness test would leave `key` unbound for an empty
                # structural_tag and crash the scheduler with UnboundLocalError.)
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and route ``req`` by server role.

        * DECODE: add to the decode pre-allocation queue (no KV prefetch).
        * PREFILL: try a KV prefetch, then add to the bootstrap queue.
        * otherwise: try a KV prefetch, then append to ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        # Prefill and colocated servers both attempt a prefetch first.
        self._prefetch_kvcache(req)
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
        else:
            self.waiting_queue.append(req)

1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
    def _prefetch_kvcache(self, req: Req):
        """Request a storage prefetch for the part of ``req`` not already cached.

        Does nothing unless ``self.enable_hicache_storage`` is set. After
        ``req.init_next_round_input`` refreshes the match, the matched length is
        ``len(req.prefix_indices) + req.host_hit_length``. The tokens of
        ``req.fill_ids`` past that length are passed to
        ``tree_cache.prefetch_from_storage`` together with the last host node's
        hash. Nothing is prefetched when some tokens matched but that hash is
        ``None``.
        """
        if not self.enable_hicache_storage:
            return

        req.init_next_round_input(self.tree_cache)
        anchor_hash = req.last_host_node.get_last_hash_value()
        hit_len = len(req.prefix_indices) + req.host_hit_length

        # TODO: free-form fetching, calculating hash keys on the fly.
        if hit_len == 0 or (hit_len > 0 and anchor_hash is not None):
            remaining_tokens = req.fill_ids[hit_len:]
            self.tree_cache.prefetch_from_storage(
                req.rid, req.last_host_node, remaining_tokens, anchor_hash
            )

1278
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Bulk-enqueue ``reqs`` into the queue for this server's role.

        ``is_retracted`` is forwarded only to the decode pre-allocation queue;
        the other roles ignore it. Unlike ``_add_request_to_queue``, this
        neither prefetches KV cache nor sets ``queue_time_start``.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return
        if mode == DisaggregationMode.DECODE:
            # Decode servers park requests in the pending pre-allocation queue.
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return
        self.waiting_queue.extend(reqs)
1288
1289
1290

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs are first expanded into placeholder tokens. A request
        whose prompt is too long, either after that expansion or according to
        ``validate_input_length``, is marked with ``set_finish_with_abort`` and
        enqueued rather than dropped. This is the same handling as
        ``handle_generate_request``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Mark the request aborted (as the generate path does) so an
            # over-long prompt is not scheduled as a normal request.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1334

1335
1336
1337
1338
1339
    def self_check_during_idle(self):
        """Run idle-time housekeeping.

        Runs the memory-leak and tree-cache checks, resets the new-token ratio
        to its initial value, then gives ``maybe_sleep_on_idle`` a chance to
        back off.
        """
        for run_check in (self.check_memory, self.check_tree_cache):
            run_check()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()

Lianmin Zheng's avatar
Lianmin Zheng committed
1341
    def check_memory(self):
        """Check that no KV-cache token slots or request slots are leaked, then log idle metrics.

        Token check:
          - Hybrid models: both the full and the SWA pools must report zero
            tokens in use.
          - Other models: ``available + evictable`` must equal
            ``max_total_num_tokens``. When the hierarchical cache is enabled,
            the expected total is ``max_total_num_tokens - protected_size``.

        Request-slot check: every ``req_to_token_pool`` slot must be free. On a
        decode-disaggregated server this also counts the pool's pre-allocated
        slots.

        Raises:
            ValueError: if either check fails. The message lists the sizes that
                were compared.

        When metrics are enabled and more than 30 s have passed since
        ``metrics_collector.last_log_time``, the scheduler stats are refreshed
        and logged. Finally ``self._publish_kv_events()`` is called.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # Protected tree-cache tokens are subtracted from the expected total
            # only when the hierarchical cache is enabled.
            memory_leak = (available_size + evictable_size) != (
                self.max_total_num_tokens
                if not self.enable_hierarchical_cache
                else self.max_total_num_tokens - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        # A decode-disaggregated server also owns the pool's pre-allocated slots.
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        num_free_req_slots = len(self.req_to_token_pool.free_slots)
        if num_free_req_slots != req_total_size:
            # Separate the header from the sizes and report the total that was
            # actually compared (it includes pre-allocated slots in decode mode).
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_req_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

Hanming Lu's avatar
Hanming Lu committed
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
    def check_tree_cache(self):
        """Run ``sanity_check()`` on the tree cache when it is an ``SWARadixCache`` of a hybrid model."""
        if not self.is_hybrid:
            return
        if isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

1453
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        Steps:
          1. Fold the previous extend batch (``self.last_batch``) into
             ``self.running_batch``. Any chunked request that is still in
             progress is excluded, and prefill-only batches are not merged.
          2. Prefer a new prefill batch from ``get_new_batch_prefill()``.
             If there is none, advance the non-empty running decode batch via
             ``update_running_batch()``.
          3. When ``require_mlp_sync(self.server_args)`` is true, pass the
             chosen batch (or None) through ``prepare_mlp_sync_batch()``.

        Returns:
            The batch to run, or None when nothing is runnable. In the MLP-sync
            case, whatever ``prepare_mlp_sync_batch()`` returns.
        """
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            # Hand the partially prefilled request to the tree cache and release
            # its req_to_token_pool slot before its next chunk is scheduled.
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            # Requests were dropped, so the running batch may have room again.
            if self.last_batch.batch_size() < last_bs:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch.
            # For prefill-only batch, we can avoid going through decoding step.
            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)

        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            # If syncing produced a batch, it is already prepared; only a None
            # result still needs the sync step below (for the decode batch).
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if need_dp_attn_preparation:
            # With the "minimum_tokens" load-balance method, refresh the shared DP
            # balance data on steps where forward_ct is a multiple of 40.
            if (
                self.server_args.load_balance_method == "minimum_tokens"
                and self.forward_ct % 40 == 0
            ):
                self.handle_dp_balance_data(ret)
            ret = self.prepare_mlp_sync_batch(ret)

        return ret

1518
1519
1520
1521
1522
1523
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted next to ``running_bs`` running ones.

        The budget is ``max_micro_batch_size`` from the global server args minus
        ``running_bs``. When ``pp_size > 1`` it is further capped by the free
        slots of ``req_to_token_pool``. The result can be zero or negative.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1524
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Assemble the next prefill (extend) batch, or return None if nothing can be scheduled.

        Admission order:
          - The in-progress chunked request (``self.chunked_req``), if any, goes
            first.
          - Then requests come from ``self.waiting_queue`` after
            ``self.policy.calc_priority`` has been applied to it (presumably
            reordering it in place; confirm in the policy).

        Admission stops when any of these hold:
          - the LoRA adapters can no longer run together;
          - the request budget from ``get_num_allocatable_reqs`` is used up;
          - in prefill-disaggregation mode, no ``req_to_token_pool`` slots
            remain;
          - ``PrefillAdder.add_one_req`` returns anything other than
            ``AddReqResult.CONTINUE``.

        Requests whose HiCache storage prefetch is still in progress are
        skipped.

        Side effects:
          - updates ``running_batch.batch_is_full``;
          - removes admitted requests from ``waiting_queue``;
          - may set or advance ``self.chunked_req``;
          - in mixed-chunk mode, folds the running decode requests into the new
            batch and replaces ``running_batch`` with an empty batch.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {r.lora_id for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | {r.lora_id for r in adder.can_run_list}
                | {req.lora_id}
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc and transfer queues can also hold
                # req_to_token_pool slots, so check against the slots that are
                # actually available.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Building it inside the comprehension
        # re-created it for every waiting request (O(len(queue) * len(admitted))).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1676
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch in place and return it.

        Steps:
          - Filter the batch (``filter_batch()``).
          - If the next decode step does not fit in memory (or ``TEST_RETRACT``
            forces it for batches larger than 10), retract requests back to the
            queue and adopt the ratio returned by ``retract_decode``.
            Otherwise, decay ``new_token_ratio`` toward ``min_new_token_ratio``.
          - If the batch shrank, clear ``batch_is_full``.
          - Prepare the batch tensors for the next decode step.

        An empty batch is returned right after filtering, with
        ``batch_is_full`` cleared.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        fits_in_memory = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        must_retract = not fits_in_memory or (
            TEST_RETRACT and batch.batch_size() > 10
        )

        if must_retract:
            previous_ratio = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            n_retracted = len(retracted)
            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {n_retracted}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted, is_retracted=True)
            self.total_retracted_reqs += n_retracted
        else:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Refresh the batch tensors for the next decode step.
        batch.prepare_for_decode()
        return batch

1716
1717
1718
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Before the forward pass, this increments ``self.forward_ct``, runs the
        profiler predicate, and optionally sleeps ``self.forward_sleep_time``
        seconds.

        Generation models: the forward pass goes through ``tp_worker``, or
        through ``draft_worker`` when a speculative algorithm is configured.
        The result is a ``GenerationBatchResult``. On the last pipeline rank it
        carries logits and next token ids; on other ranks it carries the
        hidden-state proxy tensors.

        Embedding/reward models: the result is an ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()

                # update the consumer index of hicache to the running batch
                self.tp_worker.set_hicache_consumer(
                    model_worker_batch.hicache_consumer_index
                )
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    # Non-last pipeline ranks produce hidden-state proxies for the
                    # next stage instead of logits / token ids.
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                # NOTE(review): this path never binds pp_hidden_states_proxy_tensors,
                # so on a non-last PP rank the GenerationBatchResult below would
                # raise NameError; presumably speculative decoding + PP is
                # unsupported. Confirm.
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                bs = batch.batch_size()
                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
                self.spec_num_total_forward_ct += bs
                self.num_generated_tokens += num_accepted_tokens

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None
            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret

1797
1798
1799
1800
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Route ``result`` to the handler that matches ``batch.forward_mode``.

        - Decode results go to ``process_batch_result_decode``.
        - Extend results go to ``process_batch_result_prefill``.
        - An idle batch needs work only with the overlap scheduler enabled: the
          worker's last result is resolved, then
          ``set_next_batch_sampling_info_done`` is called.
        - A dummy-first batch only calls ``set_next_batch_sampling_info_done``.

        ``maybe_send_health_check_signal()`` runs afterwards in every case.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            handler = self.process_batch_result_decode
        elif mode.is_extend():
            handler = self.process_batch_result_prefill
        else:
            handler = None

        if handler is not None:
            handler(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        """Send one owed ``HealthCheckOutput`` over ``send_to_tokenizer``, if any are owed.

        Answering from the batch loop keeps health checks from being blocked
        by a long-context prefill. However, this path does not check the status
        of the detokenizer manager.
        """
        if not self.return_health_check_ct:
            return
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1824
1825
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Prepare ``local_batch`` for MLP sync across DP workers using this scheduler's settings.

        Thin wrapper around ``prepare_mlp_sync_batch_raw``. It supplies:
          - the DP/attention-TP topology and the TP group;
          - the idle-batch factory;
          - the CUDA-graph, speculative-decoding and overlap-schedule options
            from ``self`` and ``self.server_args``.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            # Topology
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            # Batch construction / execution options
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
        )

1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
        """Publish this TP group's load and newly received request ids for DP load balancing.

        Steps:
          1. Every TP rank packs its load (``self.get_load()``) and the balance
             ids it received this term (``self.recv_dp_balance_id_this_term``)
             into a fixed-size int32 tensor.
          2. The tensors are gathered with ``torch.distributed.gather`` on
             ``self.tp_cpu_group``.
          3. The TP-rank-0 process removes the received ids from the shared
             on-the-fly tables. Under ``balance_meta.mutex`` it writes those
             tables and the gathered per-worker loads back to shared metadata.

        The local id list is cleared on every rank. ``local_batch`` is not used.

        Raises:
            AssertionError: if more ids were received this term than fit in the
                gather tensor, or if the shared bookkeeping is inconsistent.
        """
        # Size of the int32 tensor exchanged by the gather. The first two slots
        # are a header (holding tokens, id count); the rest carry the ids.
        gather_tensor_size = 512
        max_ids_per_gather = gather_tensor_size - 2

        def gather_dp_balance_info(
            holding_tokens,
        ) -> Tuple[Optional[List[List[int]]], Optional[List[int]]]:
            """Gather (ids received, holding tokens) from all TP ranks.

            Returns ``(ids_per_worker, holding_tokens_per_worker)`` on TP rank 0
            and ``(None, None)`` elsewhere.
            """
            recv_list = self.recv_dp_balance_id_this_term
            # The header takes two slots, so at most gather_tensor_size - 2 ids
            # fit; anything larger would fail in the slice assignment below.
            assert len(recv_list) <= max_ids_per_gather, (
                "The number of requests received this round is too large. "
                "Please increase gather_tensor_size and onfly_info_size."
            )

            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
            recv_tensor[0] = holding_tokens
            recv_tensor[1] = len(recv_list)
            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
                recv_list, dtype=torch.int32
            )

            # Only the TP-rank-0 process allocates receive buffers.
            if self.tp_rank == 0:
                gathered_list = [
                    torch.zeros(gather_tensor_size, dtype=torch.int32)
                    for _ in range(self.balance_meta.num_workers)
                ]
            else:
                gathered_list = None

            torch.distributed.gather(
                recv_tensor, gathered_list, group=self.tp_cpu_group
            )

            if self.tp_rank != 0:
                return None, None

            gathered_id_list_per_worker = []
            holding_tokens_list = []
            for tensor in gathered_list:
                holding_tokens_list.append(tensor[0].item())
                list_length = tensor[1].item()
                gathered_id_list_per_worker.append(
                    tensor[2 : list_length + 2].tolist()
                )
            return gathered_id_list_per_worker, holding_tokens_list

        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
            """Reconcile the shared on-the-fly tables and store per-worker loads."""
            meta = self.balance_meta

            with meta.mutex:
                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
                assert len(new_recv_rid_lists) == len(
                    onfly_list
                ), "num_worker not equal"
                # 1. Every rid a worker received this round must be present in
                #    its on-the-fly table; remove it now that it has arrived.
                for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(
                    zip(new_recv_rid_lists, onfly_list)
                ):
                    for new_recv_rid in new_recv_rids:
                        assert (
                            new_recv_rid in on_fly_reqs
                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                        del on_fly_reqs[new_recv_rid]
                # 2. Atomically write local_tokens and onfly into shm under the mutex
                meta.set_shared_onfly_info(onfly_list)
                meta.set_shared_local_tokens(local_tokens)

        holding_tokens = self.get_load()

        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
            holding_tokens
        )

        self.recv_dp_balance_id_this_term.clear()
        if self.tp_rank == 0:  # only first worker write info
            write_shared_dp_balance_info(
                new_recv_dp_balance_id_list, holding_token_list
            )

1918
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        """All-gather per-rank batch metadata and annotate the local batch with it.

        Each rank contributes one int64 row: token count, cuda-graph eligibility,
        logprob token count, extend flag, followed by the values returned from
        ``TboDPAttentionPreparer.prepare_all_gather`` (presumably two, since the
        gathered tensor has a last dimension of 6 — confirm). Rows are gathered
        into a ``(dp_size, attn_tp_size, 6)`` tensor over the device group when
        ``disable_overlap_schedule`` is set, otherwise over the CPU group.

        If this rank has no batch but some DP rank has tokens, an idle batch is
        obtained from ``get_idle_batch`` so the rank still takes part.

        ``spec_algorithm`` and ``speculative_num_draft_tokens`` are accepted but
        not read in this body.

        Returns:
            The (possibly newly created) local batch, or ``None`` when this rank
            had no batch and no DP rank reported any tokens.
        """
        # Local token counts that are shared with the other DP workers.
        if local_batch is None:
            num_tokens = num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = num_tokens_for_logprob = local_batch.batch_size()
        else:
            num_tokens = local_batch.extend_num_tokens
            # Every request samples at least one token, hence the floor of 1.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        can_cuda_graph = int(
            local_batch is None or local_batch.forward_mode.is_decode_or_idle()
        )
        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Choose where the metadata exchange happens.
        if disable_overlap_schedule:
            comm_group, comm_device = tp_group.device_group, tp_group.device
        else:
            comm_group, comm_device = tp_group.cpu_group, "cpu"

        row = [
            num_tokens,
            can_cuda_graph,
            num_tokens_for_logprob,
            is_extend_in_batch,
            *tbo_preparer.prepare_all_gather(local_batch),
        ]
        local_info = torch.tensor(row, dtype=torch.int64, device=comm_device)
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6), dtype=torch.int64, device=comm_device
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=comm_group
        )

        # Only the first attention-TP rank of every DP group is read back.
        first_tp_rows = global_info[:, 0, :]
        global_num_tokens = first_tp_rows[:, 0].tolist()
        can_cuda_graph = min(first_tp_rows[:, 1].tolist())
        global_num_tokens_for_logprob = first_tp_rows[:, 2].tolist()
        any_extend = any(first_tp_rows[:, 3].tolist())

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()
        if local_batch is None:
            return None

        # TODO: handle the case when moe_dense_tp_size != 1
        if require_mlp_tp_gather:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
        else:
            local_batch.global_num_tokens = [num_tokens]
            local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
        local_batch.is_extend_in_batch = any_extend
        local_batch.tbo_split_seq_index = tbo_split_seq_index
        local_batch.global_forward_mode = global_forward_mode

        # Cuda-graph eligibility is the minimum across DP ranks.
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
2021
2022
2023
2024
2025

    def get_idle_batch(self):
        """Create an empty ScheduleBatch that has been prepared for an idle step.

        The batch is built from this scheduler's request/KV pools, tree cache,
        model config, overlap flag, and speculative algorithm, with no requests.
        """
        no_reqs = []
        batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        batch.prepare_for_idle()
        return batch

2035
2036
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order, waiting at most ``poll_timeout`` seconds
        on each pending grammar future. The scan stops at the first future that
        is not ready yet. A request whose wait counter exceeds
        ``GRAMMAR_TIMEOUT / poll_timeout`` polls is treated as timed out.

        With TP size > 1, the ready and timed-out counts are synchronized with a
        MAX all-reduce so that every rank moves the same prefix of the queue.
        Ranks that are behind block on ``result()`` for the extra ready requests.
        Timed-out requests are cancelled, aborted, and cached as
        ``INVALID_GRAMMAR_OBJ``.
        """
        poll_timeout = 0.03
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=poll_timeout)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    # At most one request can time out per call because we break here.
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Catch up with the fastest rank by blocking on the remaining futures.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests sit directly after the *synchronized* ready prefix.
        # Indexing from the local count would, on a lagging rank, cancel a request
        # that was just resolved above and leave an unresolved future in the
        # slice that is moved to the waiting queue.
        timeout_start = num_ready_reqs_max
        for i in range(timeout_start, timeout_start + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

2100
2101
2102
2103
2104
2105
2106
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Finalize the next batch's sampling info and signal that it is ready.

        When grammars are attached, the regex vocab mask is updated and the
        current stream is synchronized before ``sampling_info_done`` is set.
        Does nothing if the batch has no next-batch sampling info.
        """
        sampling_info = batch.next_batch_sampling_info
        if not sampling_info:
            return
        if sampling_info.grammars is not None:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        sampling_info.sampling_info_done.set()

2107
2108
2109
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        The loop exits once a batch is active (``self.cur_batch`` is not None) and
        ``self.forward_ct`` has not advanced for more than ``self.watchdog_timeout``
        seconds. It then logs diagnostics (unless request logging is disabled) and
        dumps scheduler stacks via py-spy. After a 5 second grace period it sends
        SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: `// 2` floors to 0 for timeouts below 2 seconds,
            # which would turn this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        # Snapshot the batch once. The scheduler thread keeps running and may
        # reset `cur_batch` to None after the check above. Dereferencing None
        # would kill this thread before SIGQUIT is ever sent.
        stuck_batch = self.cur_batch

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "

            if stuck_batch is not None:
                # Same text as the f"{expr=}" form (which uses repr), but reading
                # the snapshot instead of the live attribute.
                logger.error(
                    f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                    f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                    f"{info_msg}"
                )
            else:
                logger.error(info_msg)

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2160
2161
2162
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a FlushCacheReqInput by running flush_cache and reporting the outcome."""
        return FlushCacheReqOutput(success=self.flush_cache())
2163

2164
    def flush_cache(self):
        """Flush the memory pool and cache.

        Flushing only happens when the scheduler is idle:
        - the waiting queue is empty,
        - the running batch is empty,
        - with pipeline parallelism, every running micro-batch is empty.

        On flush it:
        - clears the current and last batch,
        - resets the tree cache and grammar backend,
        - clears the request/KV pools, including the draft worker's pools when a
          speculative algorithm is set,
        - zeroes the generation and speculative-decoding counters,
        - calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger (as everywhere else in this file) rather than
            # the root logger, so the message gets this scheduler's configuration.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

Liangsheng Yin's avatar
Liangsheng Yin committed
2201
2202
    def get_load(self):
        """Estimate this scheduler's load in tokens.

        Two parts are added together:
        - Occupied KV-cache tokens: capacity minus available minus evictable. For
          hybrid (SWA) pools, the larger of the full-attention and SWA figures.
        - Prompt lengths of queued requests: the waiting queue, plus the
          bootstrap queue in PREFILL disaggregation mode or the prealloc queue
          in DECODE mode.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache

        if not self.is_hybrid:
            load = (
                self.max_total_num_tokens
                - allocator.available_size()
                - cache.evictable_size()
            )
        else:
            full_used = (
                self.full_tokens_per_layer
                - allocator.full_available_size()
                - cache.full_evictable_size()
            )
            swa_used = (
                self.swa_tokens_per_layer
                - allocator.swa_available_size()
                - cache.swa_evictable_size()
            )
            load = max(full_used, swa_used)

        load += sum(len(r.origin_input_ids) for r in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            bootstrap = self.disagg_prefill_bootstrap_queue.queue
            load += sum(len(r.origin_input_ids) for r in bootstrap)
        elif mode == DisaggregationMode.DECODE:
            prealloc = self.disagg_decode_prealloc_queue.queue
            load += sum(len(d.req.origin_input_ids) for d in prealloc)

        return load

2235
2236
2237
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of server args and runtime statistics.

        The snapshot includes:
        - a copy of ``global_server_args_dict``,
        - the last generation throughput,
        - memory usage (weights, KV cache, token capacity, and cuda graphs when
          not on CPU),
        - the average speculative accept length when available,
        - step times when ``RECORD_STEP_TIME`` is set,
        - the current load.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        model_runner = self.tp_worker.worker.model_runner
        memory_usage = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }
        if not _is_cpu:
            memory_usage["cuda_graph"] = round(model_runner.cuda_graph_mem_usage, 2)
        state["memory_usage"] = memory_usage

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
2263
2264
2265
2266
2267

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Apply a whitelisted set of server-arg updates to ``global_server_args_dict``.

        The update is all-or-nothing. Every key is validated first, and nothing
        is written if any key is outside the allow-list. Nothing is written
        either if ``max_micro_batch_size`` is outside
        ``[1, max_running_requests // pp_size]``.

        On success, the average speculative accept length is logged (when
        available) and the cumulative counters are reset.

        Returns:
            SetInternalStateReqOutput whose ``updated`` flag reports whether the
            update was actually applied, plus the current global server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        return SetInternalStateReqOutput(
            # Report whether the update was actually applied.
            updated=if_success,
            server_args=global_server_args_dict,
        )

2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke the scheduler method named by ``recv_req.method``.

        The method is called with ``recv_req.parameters`` as its only positional
        argument. An ``Exception`` raised by the lookup or the call is logged and
        reported in the output rather than propagated. A ``torch.distributed``
        barrier runs after the call attempt.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        succeeded = error is None
        return RpcReqOutput(succeeded, "" if not error else str(error))

2321
2322
    def abort_request(self, recv_req: AbortReq):
        """Abort every request matching ``recv_req`` wherever it lives in the scheduler.

        A request matches when ``recv_req.abort_all`` is set or its rid starts
        with ``recv_req.rid``. Which abort method applies depends on the queue the
        request is in (see the numbered comments below).
        """

        def matches(rid):
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [i for i, req in enumerate(self.waiting_queue) if matches(req.rid)]

        # Pop in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if matches(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled.
        # Debug logs are emitted only for requests that actually match.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if matches(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if matches(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and matches(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2394

2395
2396
2397
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause hook; this scheduler does not implement it and always raises."""
        raise NotImplementedError

2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        Delegates to the TP worker and returns its result unchanged.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter.

        Delegates to the TP worker and returns its result unchanged.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2414
2415
2416
2417
2418
2419
2420
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store the requested per-forward sleep time.

        A non-positive value clears the setting (stored as ``None``).
        """
        requested = recv_req.forward_sleep_time
        clear = requested is not None and requested <= 0
        self.forward_sleep_time = None if clear else requested
        return SlowDownReqOutput()

2421
2422
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump the global expert-distribution recorder.

        Raises:
            ValueError: if ``recv_req`` is not one of the supported commands.
        """
        # Checked in order with ==; the first match wins.
        commands = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for command, method_name in commands:
            if recv_req == command:
                recorder = get_global_expert_distribution_recorder()
                getattr(recorder, method_name)()
                return ExpertDistributionReqOutput()
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
2431

2432
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        Fails (returns success=False) if the id is already in use or is None.
        """
        session_id = recv_req.session_id

        # Reject duplicates and missing ids.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2446
2447
2448
2449
2450
2451
2452
2453
2454

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``; warn if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
        else:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")

2455
2456
    def get_print_prefix(self):
        """Build a log prefix naming this scheduler's DP/TP/PP ranks, e.g. " DP0 TP1".

        The DP rank is included when set. The TP and PP ranks are included only
        when the corresponding size is greater than 1.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2465
2466
    def current_scheduler_metrics_enabled(self):
        """Whether this scheduler should emit metrics.

        True on attention-TP rank 0; otherwise ``enable_metrics_for_all_schedulers``.
        """
        if self.attn_tp_rank == 0:
            return True
        return self.enable_metrics_for_all_schedulers
2467

2468
2469
2470
    def maybe_sleep_on_idle(self):
        """Delegate to the idle sleeper, if one is configured."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2471

2472

2473
2474
2475
2476
2477
2478
2479
class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This would lead not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important in cases when multiple GPUs are connected
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    def __init__(self, sockets, poll_timeout_ms: int = 1000):
        """Register ``sockets`` for read-readiness polling.

        Args:
            sockets: zmq sockets whose incoming data should wake the scheduler.
            poll_timeout_ms: Upper bound, in milliseconds, on how long
                ``maybe_sleep`` blocks waiting for a socket to become readable.
        """
        self.poller = zmq.Poller()
        self.poll_timeout_ms = poll_timeout_ms
        # Monotonic clock: the empty-cache interval must not be distorted by
        # wall-clock adjustments (NTP steps, manual changes).
        self.last_empty_time = time.monotonic()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a registered socket is readable or the poll timeout elapses.

        Afterwards, release cached CUDA memory if more than
        ``global_config.torch_empty_cache_interval`` seconds have passed since
        the last release. A non-positive interval disables this.
        """
        self.poller.poll(self.poll_timeout_ms)
        interval = global_config.torch_empty_cache_interval
        now = time.monotonic()
        if interval > 0 and now - self.last_empty_time > interval:
            self.last_empty_time = now
            torch.cuda.empty_cache()
2500

2501

2502
2503
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health checks are recognised by an ``rid`` starting with ``"HEALTH_CHECK"``.
    Objects whose ``rid`` is missing, ``None``, or not a string are not health
    checks. The original ``getattr(..., "")`` default only covered a missing
    attribute and raised AttributeError for ``None``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
2504

2505
2506
2507

def is_work_request(recv_req):
    """Return True for tokenized generate or embedding requests (i.e. model work)."""
    work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_types)
2508
2509


2510
2511
2512
2513
2514
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    """Entry point of a scheduler subprocess.

    Sets up the process:
    - resolves the DP rank;
    - builds a log prefix and sets the process title;
    - enables ``faulthandler`` and calls ``kill_itself_when_parent_died()``;
    - configures logging and, optionally, CPU affinity.

    It then constructs a ``Scheduler`` and sends a ready message on
    ``pipe_writer``. Finally it runs the event loop selected by the
    scheduler's disaggregation mode, its ``enable_overlap`` flag, and
    ``server_args.pp_size``. On any exception the traceback is logged and
    SIGQUIT is sent to the parent process.

    Args:
        server_args: Server configuration. ``tp_size``/``ep_size``/``pp_size``
            decide which ranks appear in the log prefix; ``pp_size`` also
            selects the event loop.
        port_args: Forwarded to ``Scheduler``.
        gpu_id: GPU index for this scheduler; also used for CPU affinity when
            ``SGLANG_SET_CPU_AFFINITY`` is enabled.
        tp_rank: Tensor-parallel rank.
        moe_ep_rank: MoE expert-parallel rank.
        pp_rank: Pipeline-parallel rank.
        dp_rank: Data-parallel rank, or ``None``. When ``None`` and the
            ``SGLANG_DP_RANK`` environment variable is set, its integer value
            is used instead.
        pipe_writer: Object with a ``send`` method. A dict with ``"status":
            "ready"`` plus ``max_total_num_tokens`` and ``max_req_input_len``
            is sent on it once the scheduler has been constructed.
        balance_meta: Forwarded to ``Scheduler`` as ``dp_balance_meta``.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exists, use it as dp_rank.
    # Resolve it before the prefix is built so that the process title and the
    # log prefix also show the DP rank in this case.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.ep_size > 1:
        prefix += f" EP{moe_ep_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                if server_args.pp_size > 1:
                    scheduler.event_loop_pp_disagg_prefill()
                else:
                    scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)