scheduler.py 104 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
53
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
54
    ReqToMetadataIdxAllocator,
55
    TransferBackend,
56
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
57
)
58
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
59
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
68
69
from sglang.srt.managers.io_struct import (
    AbortReq,
70
    CloseSessionReqInput,
71
    ExpertDistributionReq,
72
    ExpertDistributionReqOutput,
73
74
    FlushCacheReqInput,
    FlushCacheReqOutput,
75
76
    GetInternalStateReq,
    GetInternalStateReqOutput,
77
    GetWeightsByNameReqInput,
78
    HealthCheckOutput,
79
    InitWeightsUpdateGroupReqInput,
80
81
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
82
83
    OpenSessionReqInput,
    OpenSessionReqOutput,
84
    ProfileReq,
85
86
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
87
88
    RpcReqInput,
    RpcReqOutput,
89
90
    SetInternalStateReq,
    SetInternalStateReqOutput,
91
92
    SlowDownReqInput,
    SlowDownReqOutput,
93
94
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
95
96
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
97
    UpdateWeightFromDiskReqInput,
98
    UpdateWeightsFromDistributedReqInput,
99
    UpdateWeightsFromTensorReqInput,
100
)
101
from sglang.srt.managers.mm_utils import init_embedding_cache
102
103
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
104
    MultimodalInputs,
105
106
    Req,
    ScheduleBatch,
107
    global_server_args_dict,
108
)
109
110
111
112
113
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
114
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
115
116
117
118
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
119
120
121
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
122
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
123
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
124
125
126
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
127
from sglang.srt.managers.session_controller import Session
128
from sglang.srt.managers.tp_worker import TpModelWorker
129
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
130
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
tarinkk's avatar
tarinkk committed
131
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
132
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
133
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
134
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Lianmin Zheng's avatar
Lianmin Zheng committed
135
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
136
from sglang.srt.reasoning_parser import ReasoningParser
137
from sglang.srt.server_args import PortArgs, ServerArgs
138
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
139
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
140
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
141
from sglang.srt.utils import (
142
    DynamicGradMode,
143
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
144
    configure_gc_logger,
145
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
146
    disable_request_logging,
147
    get_available_gpu_memory,
148
    get_bool_env_var,
149
    get_zmq_socket,
150
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
151
    kill_itself_when_parent_died,
152
    point_to_point_pyobj,
153
    pyspy_dump_schedulers,
154
155
    require_mlp_sync,
    require_mlp_tp_gather,
156
    set_gpu_proc_affinity,
157
158
159
    set_random_seed,
    suppress_other_loggers,
)
160
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
161
162
163

logger = logging.getLogger(__name__)

# Debugging aid: when SGLANG_TEST_RETRACT is truthy, the retract-decode path is
# exercised for testing purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")

# Grammar timeout taken from SGLANG_GRAMMAR_TIMEOUT (default 300), as a float.
# NOTE(review): presumably seconds -- confirm against the code that reads it.
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))

# Platform probe, evaluated once at import time.
_is_cpu = is_cpu()
@dataclass
class GenerationBatchResult:
    """Outputs collected for a single generation batch.

    Fields are positional in the generated ``__init__``; keep their order.
    """

    # Output of the logits processor; Optional -- presumably None on ranks that
    # do not produce logits (e.g. non-final pipeline stages). TODO confirm.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors for pipeline parallelism (inferred from the
    # name; verify against the PP code path).
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Next token ids for the batch, when available.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths and logprob start lengths.
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Batch identifier.
    bid: int
    # Whether the forward pass was eligible to run with a CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Outputs collected for a single embedding batch (positional: embeddings, bid)."""

    # Embeddings produced for the batch.
    embeddings: torch.Tensor
    # Batch identifier.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
188
189
class Scheduler(
    SchedulerOutputProcessorMixin,
190
191
192
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
193
194
195
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
196
197
198
199
200
201
202
203
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        dp_balance_meta: Optional[DPBalanceMeta] = None,
    ):
        """Build the scheduler: IPC sockets, model/draft workers, caches and state.

        Initialization order matters: the tokenizer (which sets
        ``self.model_config`` / ``self.is_generation``) is loaded before the
        worker, the worker before the memory pool/cache, and the cache before
        the scheduling policy.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: GPU this scheduler drives.
            tp_rank: Tensor-parallel rank.
            moe_ep_rank: MoE expert-parallel rank.
            pp_rank: Pipeline-parallel rank.
            dp_rank: Data-parallel rank, or None.
            dp_balance_meta: Shared metadata for DP load balancing; required
                when DP attention is used with the "minimum_tokens" balance
                method.

        Raises:
            AssertionError: if ``schedule_conservativeness`` is negative, or if
                ``dp_balance_meta`` is missing when DP attention uses
                "minimum_tokens" load balancing.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication (ZMQ context with 2 I/O threads).
        context = zmq.Context(2)
        self.idle_sleeper = None

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            # Only the first pipeline stage's attention-TP leader talks to the
            # tokenizer/detokenizer processes.
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            # Other ranks get no receivers and no-op senders, so send call
            # sites do not need rank checks.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache (needs the worker and is_hybrid above)
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill. A non-positive size (-1) disables it; None is
        # treated the same way (init_memory_pool_and_cache also allows None),
        # instead of failing on `None <= 0`.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement that moves the ratio from its initial value down
        # to the minimum over `default_new_token_ratio_decay_steps` steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. The parent-process handle is resolved first so
        # the attribute already exists when the background thread starts.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        threading.Thread(target=self.watchdog_thread, daemon=True).start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profier()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

        self.balance_meta = dp_balance_meta
        if (
            server_args.enable_dp_attention
            and server_args.load_balance_method == "minimum_tokens"
        ):
            assert dp_balance_meta is not None

        self.recv_dp_balance_id_this_term = []

538
539
    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Always sets ``self.model_config`` and ``self.is_generation``.

        - With ``skip_tokenizer_init``: ``self.tokenizer`` and
          ``self.processor`` are both set to None.
        - Multimodal models: ``self.processor`` is loaded and
          ``self.tokenizer`` is derived from it.
        - Otherwise: only ``self.tokenizer`` is loaded (``self.processor`` is
          not assigned on this path).
        """
        args = self.server_args

        self.model_config = ModelConfig.from_server_args(args)
        self.is_generation = self.model_config.is_generation

        if args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        # Loader options shared by the processor and tokenizer paths.
        loader_kwargs = dict(
            tokenizer_mode=args.tokenizer_mode,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
        )

        if not self.model_config.is_multimodal:
            self.tokenizer = get_tokenizer(args.tokenizer_path, **loader_kwargs)
            return

        self.processor = get_processor(
            args.tokenizer_path,
            use_fast=not args.disable_fast_image_processor,
            **loader_kwargs,
        )
        self.tokenizer = get_tokenizer_from_processor(self.processor)

    def init_memory_pool_and_cache(self):
        """Fetch the worker's memory pools and build the prefix cache.

        Sets ``self.req_to_token_pool`` and ``self.token_to_kv_pool_allocator``
        from the TP worker. It then picks ``self.tree_cache`` using the first
        matching rule:

        1. chunked prefill configured and radix cache disabled ->
           ``SWAChunkCache`` (hybrid) or ``ChunkCache``;
        2. ``SGLANG_EXPERIMENTAL_CPP_RADIX_TREE == "1"`` -> ``RadixCacheCpp``;
        3. hierarchical cache enabled -> ``HiRadixCache``;
        4. hybrid model -> ``SWARadixCache`` (disaggregation must be "null");
        5. otherwise -> ``RadixCache``.

        It also sets ``self.decode_mem_cache_buf_multiplier`` and initializes
        the multimodal embedding cache, sized by ``SGLANG_VLM_CACHE_SIZE_MB``
        (MiB, default 100).
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            cache_cls = SWAChunkCache if self.is_hybrid else ChunkCache
            self.tree_cache = cache_cls(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
            # lazy import to avoid JIT overhead
            from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

            self.tree_cache = RadixCacheCpp(
                disable=False,
                use_hicache=self.enable_hierarchical_cache,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )
        elif self.enable_hierarchical_cache:
            # With DP attention the cache is coordinated within the attention
            # TP group instead of the full TP group.
            if args.enable_dp_attention:
                cache_group = self.attn_tp_cpu_group
            else:
                cache_group = self.tp_cpu_group
            # hot fix for incompatibility: fa3 forces the "direct" IO backend.
            if args.attention_backend == "fa3":
                io_backend = "direct"
            else:
                io_backend = args.hicache_io_backend

            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=cache_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
                hicache_io_backend=io_backend,
                hicache_mem_layout=args.hicache_mem_layout,
                hicache_storage_backend=args.hicache_storage_backend,
                hicache_storage_prefetch_policy=args.hicache_storage_prefetch_policy,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
            )
        elif self.is_hybrid:
            assert (
                args.disaggregation_mode == "null"
            ), "Hybrid mode does not support disaggregation yet"
            self.tree_cache = SWARadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                sliding_window_size=self.sliding_window_size,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        # Without speculative decoding each decode step needs 1 slot per
        # request; otherwise reserve for draft tokens plus topk * steps.
        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens
                + args.speculative_eagle_topk * args.speculative_num_steps
            )

        cache_size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(cache_size_mb * 1024 * 1024)

Byron Hsu's avatar
Byron Hsu committed
663
    def init_disaggregation(self):
        """Create the PD-disaggregation transfer backend and per-role queues.

        ``self.transfer_backend`` is always built from the server args.

        In DECODE mode this method creates:
        - the metadata index allocator and metadata buffers,
        - the KV transfer queue,
        - the pre-allocation queue.

        In PREFILL mode it creates:
        - the metadata index allocator and metadata buffers,
        - the bootstrap queue,
        - an empty in-flight list.

        Any other mode creates nothing further.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def setup_metadata(num_slots):
            # Both roles build the same allocator/buffer pair; only the slot
            # count differs between them.
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                num_slots
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                num_slots,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

        def draft_kv_pool():
            # The draft worker's KV pool, or None when there is no draft worker.
            if self.draft_worker is None:
                return None
            return self.draft_worker.model_runner.token_to_kv_pool

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            setup_metadata(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )
        elif mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            setup_metadata(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
755

756
    @DynamicGradMode()
    def event_loop_normal(self):
        """Scheduler loop without CPU/GPU overlap.

        Each iteration:
        1. ingests newly received requests,
        2. picks the next batch,
        3. runs it and processes its result synchronously.

        When no batch is available, the iteration is spent on idle-time
        self-checks instead.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Nothing runnable: do self-check and re-init some states.
                self.self_check_during_idle()
            else:
                self.process_batch_result(batch, self.run_batch(batch))

            self.last_batch = batch
774

775
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Scheduler loop that overlaps CPU processing with GPU computation.

        Each iteration first launches the newly scheduled batch. Only then
        does it post-process the result of the batch launched in the previous
        iteration. Launched batches wait in ``self.result_queue`` as
        ``(batch copy, result)`` pairs until their turn comes.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched))

                if self.last_batch is None:
                    # No batch was launched in the previous iteration, so
                    # there is no earlier result to process. Push a
                    # DUMMY_FIRST placeholder through result processing to
                    # start the pipeline; it is used for triggering the
                    # sampling_info_done event.
                    placeholder = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(placeholder, None, batch.launch_done)

            if self.last_batch:
                # Finish the batch that was launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # Pass the launch_done event of the batch launched in THIS
                # iteration, not the previous batch's own event.
                current_launch_done = batch.launch_done if batch else None
                self.process_batch_result(
                    prev_batch, prev_result, current_launch_done
                )
            elif batch is None:
                # No batch now and none previously: do self-check and
                # re-init some states.
                self.self_check_during_idle()

            self.last_batch = batch

818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Each outer iteration walks over ``pp_size`` micro-batch slots. For
        every slot it:
        - receives and schedules requests,
        - runs the slot's batch,
        - post-processes the batch held in the *next* slot, using the outputs
          received over ``pp_group``.

        The last rank sends its sampled outputs on. Every other rank forwards:
        - the previously received outputs,
        - the received requests,
        - its hidden-state proxy tensors.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Per-slot cuda-graph flag of the forward that produced mbs[slot].
        # The post-processing of a slot must use its own flag; the previous
        # code used whichever ``result`` happened to be bound at that time,
        # which belonged to a different (or stale) micro-batch.
        cuda_graph_flags = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    cuda_graph_flags[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs: Optional[PPProxyTensors] = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Rebuild the logits output from the "logits_output.*" keys, if any were sent.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=cuda_graph_flags[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, do self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()
950

951
952
    def recv_requests(self) -> List[Req]:
        """Collect newly arrived requests and make them visible on every rank.

        Only attention-TP rank 0 of a pipeline stage actually receives. On
        stage 0 it drains the tokenizer socket and then the RPC socket
        without blocking. On later stages it gets the list forwarded by the
        previous stage via ``point_to_point_pyobj``.

        After the optional input blocker, the list is broadcast as follows:
        - with DP attention, work requests go to the attention-TP group and
          control requests go to the whole TP group;
        - otherwise, the whole list goes to the TP group.

        Returns ``[]`` immediately when the receive skipper declines this step.
        """
        if self.recv_skipper is not None:
            prev_forward_mode = (
                None if self.last_batch is None else self.last_batch.forward_mode
            )
            if not self.recv_skipper.handle(prev_forward_mode):
                return []

        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:

                def drain(socket):
                    # Pull every message already queued on `socket`; stop at
                    # the first ZMQError (nothing more to read right now).
                    drained = []
                    while True:
                        try:
                            drained.append(socket.recv_pyobj(zmq.NOBLOCK))
                        except zmq.ZMQError:
                            return drained

                recv_reqs = drain(self.recv_from_tokenizer) + drain(
                    self.recv_from_rpc
                )
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                own_rank = self.pp_rank * self.tp_size + dp_offset
                prev_stage_rank = (self.pp_rank - 1) * self.tp_size + dp_offset
                recv_reqs = point_to_point_pyobj(
                    [],
                    own_rank,
                    self.world_group.device_group,
                    prev_stage_rank,
                    own_rank,
                )

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            work_reqs = None
            control_reqs = None
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    is_work = isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                    (work_reqs if is_work else control_reqs).append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
1040
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any reply.

        - Health-check generate requests are only counted (not run) while a
          chunked request, a non-empty running batch, or offload tags exist.
        - Work requests that arrive when the waiting queue is full are
          rejected with an abort (HTTP 503) sent to the tokenizer.
        - Handler outputs of type ``RpcReqOutput`` go back over the RPC
          socket, if one exists; all other outputs go to the tokenizer.
        """
        for incoming in recv_reqs:
            scheduler_busy_health_check = is_health_check_generate_req(incoming) and (
                self.chunked_req is not None
                or not self.running_batch.is_empty()
                or len(self.offload_tags) > 0
            )
            if scheduler_busy_health_check:
                self.return_health_check_ct += 1
                continue

            if (
                is_work_request(incoming)
                and len(self.waiting_queue) + 1 > self.max_queued_requests
            ):
                self.send_to_tokenizer.send_pyobj(
                    AbortReq(
                        incoming.rid,
                        finished_reason={
                            "type": "abort",
                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                            "message": "The request queue is full.",
                        },
                    )
                )
                continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
1071
1072
1073
1074
1075

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and queue it.

        These requests are marked aborted (``set_finish_with_abort``) and
        still passed to ``_add_request_to_queue``:
        - unknown session id,
        - multimodal prompt too long after expansion,
        - prompt too long,
        - ``logprob_start_len`` too large,
        - cached invalid grammar.

        A disaggregated request without a bootstrap room is aborted and
        streamed back immediately. Requests whose grammar is not cached yet
        go to ``grammar_queue`` instead of the normal queue.
        """
        if (
            self.server_args.enable_dp_attention
            and self.server_args.load_balance_method == "minimum_tokens"
        ):
            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)

        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_id=recv_req.lora_id,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but is not in self.sessions.
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap generation so that prompt + output stays within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # Only structural_tag can remain here, and it is known to be
                # not None. The previous truthiness test (`elif structural_tag:`)
                # left `key` unbound for an empty structural tag, which raised
                # UnboundLocalError inside the scheduler loop.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Grammar is still being compiled; park the request until it is ready.
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req`` with its queueing time and put it on the right queue.

        Decode-side disaggregation sends it straight to the pre-allocation
        queue. Every other mode first calls ``_prefetch_kvcache``. The request
        then goes to the prefill bootstrap queue (prefill-side
        disaggregation) or to the regular waiting queue.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        self._prefetch_kvcache(req)
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
        else:
            self.waiting_queue.append(req)

1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
    def _prefetch_kvcache(self, req: Req):
        if self.enable_hicache_storage:
            req.init_next_round_input(self.tree_cache)
            last_hash = req.last_host_node.get_last_hash_value()
            matched_len = len(req.prefix_indices) + req.host_hit_length
            # todo, free-form fetching, calculating hash keys on the fly
            if (matched_len > 0 and last_hash is not None) or matched_len == 0:
                new_input_tokens = req.fill_ids[matched_len:]
                self.tree_cache.prefetch_from_storage(
                    req.rid, req.last_host_node, new_input_tokens, last_hash
                )

1262
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Bulk-enqueue ``reqs`` on the queue that matches the disaggregation mode.

        Prefill servers use the bootstrap queue. Decode servers use the
        pre-allocation queue, which also receives ``is_retracted``. All other
        modes use the regular waiting queue.
        """
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return

        if mode == DisaggregationMode.DECODE:
            # If this is a decode server, we put the request to the decode pending prealloc queue
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return

        self.waiting_queue.extend(reqs)
1272
1273
1274

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and queue it.

        Multimodal image inputs are expanded into placeholder tokens first.
        A request whose prompt is too long (after expansion, or as reported
        by ``validate_input_length``) is marked aborted with
        ``set_finish_with_abort`` and still enqueued, mirroring
        ``handle_generate_request``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Abort exactly as handle_generate_request does for this case.
            # Previously the error message was discarded and the request was
            # queued as if it were valid.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1318

1319
1320
1321
1322
1323
    def self_check_during_idle(self):
        """Run idle-time housekeeping.

        In order: check memory accounting, check the tree cache, reset
        ``new_token_ratio`` to ``init_new_token_ratio``, then call
        ``maybe_sleep_on_idle``.
        """
        for check in (self.check_memory, self.check_tree_cache):
            check()
        # Nothing is running, so restart the ratio from its initial value.
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
1324

Lianmin Zheng's avatar
Lianmin Zheng committed
1325
    def check_memory(self):
        """Check KV-token and request-slot accounting for leaks, then maybe log stats.

        This is meant for idle time (it is called from ``self_check_during_idle``).
        At that point:

        * Non-hybrid models: ``available + evictable`` KV tokens must equal
          ``max_total_num_tokens``. With hierarchical cache, protected tokens
          are subtracted from that total.
        * Hybrid (SWA) models: the used-token count of both the full and the SWA
          pools must be zero.
        * Every ``req_to_token_pool`` slot must be free. Decode-disaggregation
          servers also count ``pre_alloc_size`` slots.

        Afterwards, if metrics are enabled and more than 30 seconds have passed
        since the metrics collector last logged, idle-time stats are logged.
        KV events are published in all non-raising cases.

        Raises:
            ValueError: If a KV-token leak or a req_to_token_pool slot leak is
                detected.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # With hierarchical cache, protected tokens are not counted as leaked.
            memory_leak = (available_size + evictable_size) != (
                self.max_total_num_tokens
                if not self.enable_hierarchical_cache
                else self.max_total_num_tokens - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        # Decode servers additionally own pre-allocated request slots.
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        if len(self.req_to_token_pool.free_slots) != req_total_size:
            # Report the same total that was compared against (it includes
            # pre_alloc_size in decode mode), so the message is not misleading.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1401

Hanming Lu's avatar
Hanming Lu committed
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
    def check_tree_cache(self):
        """Run ``sanity_check`` on the tree cache for hybrid models backed by ``SWARadixCache``.

        Does nothing for non-hybrid models or other cache types.
        """
        if not self.is_hybrid:
            return
        if isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

1437
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        The previous extend batch (minus any chunked request) is first folded
        into ``self.running_batch``. A freshly built prefill batch has priority;
        otherwise the running decode batch is updated and used if non-empty.
        When MLP sync (DP attention) is required, the result is passed through
        ``prepare_mlp_sync_batch``.

        Returns:
            The batch to run, or ``None`` if there is nothing to run.
        """
        # Requests that must not be merged into running_batch.
        excluded = set()
        if self.chunked_req:
            # Pull the chunked request out so only finished requests are merged.
            excluded.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the current
                # microbatch may still track an outdated chunked_req; drop it.
                excluded.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Fold what remains of the prefill batch into the running batch.
            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)
        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # Under speculative decoding, a DP attention group cannot mix prefill
            # and decode batches, so idle batches are prepared up front and the
            # decode-side preparation is skipped when prefill is present.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Prefill takes priority.
            ret = new_batch
        elif self.running_batch.is_empty():
            ret = None
        else:
            # Decode.
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = None if self.running_batch.is_empty() else self.running_batch

        if need_dp_attn_preparation:
            balance_now = (
                self.server_args.load_balance_method == "minimum_tokens"
                and self.forward_ct % 40 == 0
            )
            if balance_now:
                self.handle_dp_balance_data(ret)
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
1500

1501
1502
1503
1504
1505
1506
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be added next to ``running_bs`` running ones.

        The limit is ``global_server_args_dict["max_micro_batch_size"]`` minus
        ``running_bs``. With pipeline parallelism (``pp_size > 1``) it is further
        capped by ``req_to_token_pool.available_size()``. The result may be <= 0.
        """
        headroom = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return headroom
        return min(headroom, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1507
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue.

        Requests are admitted through a ``PrefillAdder`` until a limit is hit:
        LoRA capacity, allocatable request count, request-slot availability in
        prefill-disaggregation mode, or KV tokens. Admitted requests are removed
        from ``self.waiting_queue``. Along the way the method updates
        ``self.chunked_req`` and ``self.running_batch.batch_is_full``. With
        mixed chunking, the running decode batch may be merged into the new
        batch.

        Returns:
            The prepared ``ScheduleBatch``, or ``None`` if no request can be
            scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {req.lora_id for req in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | {admitted.lora_id for admitted in adder.can_run_list}
                | {req.lora_id}
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode the prealloc and transfer queues also hold
                # request slots, so check against the actual available size of
                # req_to_token_pool.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. The original rebuilt set(can_run_list)
        # for every waiting request, which made this filter O(len(queue) * len(batch)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1660
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        The batch is filtered, and an empty batch is returned immediately.
        Otherwise:

        * If the KV cache cannot hold another decode step (or ``TEST_RETRACT``
          forces it for batches larger than 10), some requests are retracted
          back to the queue and ``new_token_ratio`` is replaced by the value
          ``retract_decode`` returns.
        * If no retraction is needed, ``new_token_ratio`` decays toward
          ``min_new_token_ratio``.

        Finally the batch is prepared for decode and returned (the same object).
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        lacks_decode_mem = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        )
        if lacks_decode_mem or (TEST_RETRACT and batch.batch_size() > 10):
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            retracted_count = len(retracted_reqs)

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {retracted_count}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )

            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
            self.total_retracted_reqs += retracted_count
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        # Shrinking the batch frees room for new requests.
        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1699

1700
1701
1702
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap its outputs.

        Steps:

        1. Increment ``self.forward_ct`` and let the profiler inspect the batch.
        2. Optionally sleep for ``self.forward_sleep_time`` seconds.
        3. Run an embedding/reward forward, or a generation forward. The
           generation forward is either regular or speculative, via
           ``self.draft_worker``.

        Returns:
            ``EmbeddingBatchResult`` for non-generation models, otherwise a
            ``GenerationBatchResult``. Logits and next-token ids are populated
            only on the last pipeline rank; other ranks carry the pipeline
            hidden-state proxy tensors instead.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        sleep_s = self.forward_sleep_time
        if sleep_s is not None:
            logger.info(f"Scheduler.run_batch sleep {sleep_s}s")
            time.sleep(sleep_s)

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            # Point the worker's hicache consumer at this batch's index.
            self.tp_worker.set_hicache_consumer(worker_batch.hicache_consumer_index)
            forward_out = self.tp_worker.forward_batch_generation(worker_batch)
            if self.pp_group.is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            n_reqs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + n_reqs
            self.spec_num_total_forward_ct += n_reqs
            self.num_generated_tokens += num_accepted_tokens

        on_last_rank = self.pp_group.is_last_rank
        if on_last_rank:
            batch.output_ids = next_token_ids

        # Output processing needs these two values, but the overlap scheduler
        # may modify them on the batch, so snapshot them now.
        extend_input_len_per_req = (
            [req.extend_input_len for req in batch.reqs]
            if batch.return_logprob or self.spec_algorithm.is_eagle()
            else None
        )
        extend_logprob_start_len_per_req = (
            [req.extend_logprob_start_len for req in batch.reqs]
            if batch.return_logprob
            else None
        )

        if on_last_rank:
            logits_field = logits_output
            proxy_field = None
            token_ids_field = next_token_ids
        else:
            logits_field = None
            proxy_field = pp_hidden_states_proxy_tensors
            token_ids_field = None

        return GenerationBatchResult(
            logits_output=logits_field,
            pp_hidden_states_proxy_tensors=proxy_field,
            next_token_ids=token_ids_field,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1780

1781
1782
1783
1784
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch a finished batch's result by forward mode, then maybe answer a health check.

        * Decode and extend results go to their dedicated handlers.
        * Idle batches only need work when the overlap scheduler is on: the
          previous result is resolved and sampling info is marked done.
        * Dummy-first batches only mark sampling info as done.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        """Answer one pending health check, if any.

        When ``return_health_check_ct`` is non-zero, it is decremented and a
        ``HealthCheckOutput`` is sent to the tokenizer. This keeps health checks
        from being blocked behind a long-context prefill. Note: this path does
        not check the status of the detokenizer manager.
        """
        if not self.return_health_check_ct:
            return
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1808
1809
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` across DP ranks using this scheduler's configuration.

        This is a thin wrapper: it collects the relevant settings from ``self``
        and ``self.server_args`` and forwards them to
        ``prepare_mlp_sync_batch_raw``, returning whatever that returns.
        """
        args = self.server_args
        sync_kwargs = dict(
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=MoeA2ABackend(args.moe_a2a_backend).is_deepep(),
            deepep_mode=DeepEPMode(args.deepep_mode),
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
        )
        return self.prepare_mlp_sync_batch_raw(local_batch, **sync_kwargs)

1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
        """Share this worker's load and newly received request ids for DP load balancing.

        Every TP rank contributes two things to a gather over
        ``self.tp_cpu_group``: its current load (``self.get_load()``) and the
        ids in ``self.recv_dp_balance_id_this_term``. Rank 0 then, under
        ``self.balance_meta.mutex``:

        * removes those ids from the shared on-fly maps, and
        * stores the per-worker loads.

        ``recv_dp_balance_id_this_term`` is cleared on every rank.

        Args:
            local_batch: Not used by this method; kept for interface compatibility.
        """

        def gather_dp_balance_info(
            holding_tokens: int,
        ) -> Tuple[Optional[List[List[int]]], Optional[List[int]]]:
            """Gather (received ids, holding tokens) from every worker onto TP rank 0.

            Returns:
                ``(ids_per_worker, holding_tokens_per_worker)`` on rank 0 and
                ``(None, None)`` on every other rank.
            """
            recv_list = self.recv_dp_balance_id_this_term
            # The maximum size of the tensor used for gathering data from all workers.
            gather_tensor_size = 512
            # Layout: | holding_tokens | len(recv_list) | recv_list ... |
            # The two header slots leave room for gather_tensor_size - 2 ids;
            # allowing one more would fail the slice assignment below.
            max_num_ids = gather_tensor_size - 2
            assert len(recv_list) <= max_num_ids, (
                "The number of requests received this round is too large. "
                "Please increase gather_tensor_size and onfly_info_size."
            )

            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
            recv_tensor[0] = holding_tokens
            recv_tensor[1] = len(recv_list)
            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
                recv_list, dtype=torch.int32
            )

            if self.tp_rank == 0:
                gathered_list = [
                    torch.zeros(gather_tensor_size, dtype=torch.int32)
                    for _ in range(self.balance_meta.num_workers)
                ]
            else:
                gathered_list = None

            # Collective call: every rank must reach it; only rank 0 gets data.
            torch.distributed.gather(
                recv_tensor, gathered_list, group=self.tp_cpu_group
            )

            if self.tp_rank != 0:
                return None, None

            ids_per_worker = []
            holding_tokens_per_worker = []
            for tensor in gathered_list:
                holding_tokens_per_worker.append(tensor[0].item())
                num_ids = tensor[1].item()
                ids_per_worker.append(tensor[2 : num_ids + 2].tolist())
            return ids_per_worker, holding_tokens_per_worker

        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
            """Under ``meta.mutex``, retire newly received rids from the on-fly maps and store loads."""
            meta = self.balance_meta

            with meta.mutex:
                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
                assert len(new_recv_rid_lists) == len(
                    onfly_list
                ), "num_worker not equal"
                # 1. Every rid a worker received this round must be present in
                #    that worker's on-fly map; remove the matching entry.
                for worker_id, (new_recv_rids, on_fly_reqs) in enumerate(
                    zip(new_recv_rid_lists, onfly_list)
                ):
                    for new_recv_rid in new_recv_rids:
                        assert (
                            new_recv_rid in on_fly_reqs
                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                        del on_fly_reqs[new_recv_rid]
                # 2. Atomically write local_tokens and onfly into shm under the mutex
                meta.set_shared_onfly_info(onfly_list)
                meta.set_shared_local_tokens(local_tokens)

        holding_tokens = self.get_load()

        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
            holding_tokens
        )

        self.recv_dp_balance_id_this_term.clear()
        if self.tp_rank == 0:  # only first worker write info
            write_shared_dp_balance_info(
                new_recv_dp_balance_id_list, holding_token_list
            )

1907
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        """All-gather per-rank batch metadata and annotate the local batch with it.

        Each rank builds an int64 row: token count, cuda-graph eligibility (1/0),
        token count for logprobs, the extend flag, and the values returned by
        ``TboDPAttentionPreparer.prepare_all_gather``. The gathered tensor is
        shaped ``(dp_size, attn_tp_size, 6)``, so that call is expected to
        contribute two values. Only the attention-TP rank 0 row of every DP
        worker is read back.

        If this rank has no batch but some DP worker reports tokens, an idle
        batch is created through ``get_idle_batch``.

        ``spec_algorithm`` and ``speculative_num_draft_tokens`` are accepted but
        not read here.

        Returns the (possibly newly created) local batch, or None.
        """
        # Local token counts for this step.
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        can_cuda_graph = (
            1
            if local_batch is None or local_batch.forward_mode.is_decode_or_idle()
            else 0
        )
        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # With overlap scheduling enabled the exchange runs on the CPU group;
        # otherwise it uses the device group on the TP device.
        if disable_overlap_schedule:
            group, device = tp_group.device_group, tp_group.device
        else:
            group, device = tp_group.cpu_group, "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6), dtype=torch.int64, device=device
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=group
        )

        # One row per DP worker, taken from its attention-TP rank 0.
        per_dp_rows = global_info[:, 0, :]
        global_num_tokens = per_dp_rows[:, 0].tolist()
        can_cuda_graph = min(per_dp_rows[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp_rows[:, 2].tolist()
        is_extend_in_batch = per_dp_rows[:, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Nothing to run locally, but another DP worker has tokens: run idle.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is None:
            return local_batch

        # TODO: handle the case when moe_dense_tp_size != 1
        if require_mlp_tp_gather:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
        else:
            local_batch.global_num_tokens = [num_tokens]
            local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
        local_batch.is_extend_in_batch = any(is_extend_in_batch)
        local_batch.tbo_split_seq_index = tbo_split_seq_index
        local_batch.global_forward_mode = global_forward_mode

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
2016
2017
2018
2019
2020

    def get_idle_batch(self):
        """Create a request-less ScheduleBatch and prepare it for an idle step."""
        batch = ScheduleBatch.init_new(
            [],  # no requests
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

2031
2032
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order. Each grammar future is given a 0.03 s wait,
        and the scan stops at the first request that is not ready. A request
        counts as timed out once its accumulated wait count exceeds
        ``GRAMMAR_TIMEOUT / 0.03``.

        With TP size > 1, the ranks all-reduce (MAX) their ready/timeout counts.
        Ranks that were behind then block on the remaining ready grammars, so
        every rank resolves the same prefix of the queue. Timed-out requests
        directly follow that synchronized prefix. They are cancelled, aborted,
        and cached as invalid. Ready and timed-out requests are handed to
        ``_extend_requests_to_queue``; the rest stay in ``grammar_queue``.
        """
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=0.03)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_timeout_reqs = 1
                # Stop at the first unready request so only a prefix of the
                # queue is moved (at most one timeout per call).
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Block on grammars another rank already has, so this rank resolves
            # the same prefix [0, num_ready_reqs_max).
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests follow the *synchronized* ready prefix. Starting at
        # the local num_ready_reqs would, on a rank that was behind, cancel
        # entries resolved just above and skip a still-pending Future, leaving
        # ranks out of sync. Clamp in case the ready prefix already covers the
        # whole queue.
        timeout_end = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for i in range(num_ready_reqs_max, timeout_end):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = timeout_end

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

2096
2097
2098
2099
2100
2101
2102
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Finalize the batch's next-batch sampling info and signal completion.

        No-op when the batch has no (or falsy) ``next_batch_sampling_info``.
        When grammars are attached, the regex vocab mask is updated and the
        current stream is synchronized before the done event is set.
        """
        sampling_info = batch.next_batch_sampling_info
        if not sampling_info:
            return
        if sampling_info.grammars is not None:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        sampling_info.sampling_info_done.set()

2103
2104
2105
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        While ``self.cur_batch`` is not None, checks that ``self.forward_ct`` keeps
        advancing. Once it has stayed unchanged for more than
        ``self.watchdog_timeout`` seconds (``time.perf_counter`` based), the loop
        exits. Then:

        - the stuck batch and memory-pool info are logged, unless request
          logging is disabled;
        - the scheduler stacks are dumped;
        - after a 5 s pause, SIGQUIT is sent to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            # Read cur_batch once per iteration: it may be reassigned while this
            # thread runs, and the diagnostics below must describe the batch
            # that was actually observed as stuck.
            batch = self.cur_batch
            if batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: `// 2` truncates, and any timeout under 2 s would
            # become time.sleep(0), busy-spinning this thread.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
            # Labels spelled out to keep the historical log format of
            # f"{self.cur_batch.batch_size()=}" / f"{self.cur_batch.reqs=}".
            logger.error(
                f"self.cur_batch.batch_size()={batch.batch_size()!r}, "
                f"self.cur_batch.reqs={batch.reqs!r}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2156
2157
2158
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Run ``flush_cache`` and wrap its boolean outcome in a FlushCacheReqOutput."""
        return FlushCacheReqOutput(success=self.flush_cache())
2159

2160
    def flush_cache(self):
        """Flush the memory pool and cache.

        Only done when no request is waiting or running. With pipeline
        parallelism this includes every micro-batch in ``running_mbs``.

        On flush:
        - clears the current/last batch;
        - resets the tree cache and grammar backend;
        - clears the token pools, and the draft worker's pools when speculative
          decoding is enabled;
        - zeroes the generation/speculative counters;
        - empties the CUDA cache.

        Returns True if the cache was flushed. Otherwise logs a warning and
        returns False.
        """
        has_pending_work = not (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if has_pending_work:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        # Drop batch references before clearing the pools they point into.
        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            draft_runner = self.draft_worker.model_runner
            draft_runner.req_to_token_pool.clear()
            draft_runner.token_to_kv_pool_allocator.clear()

        # Reset generation and speculative-decoding statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
2197
2198
    def get_load(self):
        """Estimate this scheduler's load, in tokens.

        The load has two parts:
        - KV-cache tokens that are neither free nor evictable. For hybrid pools
          this is the larger of the full-attention and SWA values.
        - Input-token counts of queued requests: the waiting queue, plus the
          PD-disaggregation bootstrap queue (prefill) or prealloc queue (decode).
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        allocator = self.token_to_kv_pool_allocator
        if self.is_hybrid:
            full_in_use = (
                self.full_tokens_per_layer
                - allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
            swa_in_use = (
                self.swa_tokens_per_layer
                - allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
            load = max(full_in_use, swa_in_use)
        else:
            load = (
                self.max_total_num_tokens
                - allocator.available_size()
                - self.tree_cache.evictable_size()
            )

        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            bootstrap_reqs = self.disagg_prefill_bootstrap_queue.queue
            load += sum(len(req.origin_input_ids) for req in bootstrap_reqs)
        elif mode == DisaggregationMode.DECODE:
            prealloc_reqs = self.disagg_decode_prealloc_queue.queue
            load += sum(len(entry.req.origin_input_ids) for entry in prealloc_reqs)

        return load

2231
2232
2233
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the global server args plus runtime statistics.

        Starts from a copy of ``global_server_args_dict`` and adds:
        - the last generation throughput;
        - memory usage (weights, KV cache, token capacity, plus CUDA graphs
          unless ``_is_cpu``);
        - the average speculative accept length, when speculative decoding is
          enabled and has accepted at least once;
        - step timings, when ``RECORD_STEP_TIME`` is set;
        - the current ``get_load()`` value.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        model_runner = self.tp_worker.worker.model_runner
        memory_usage = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }
        state["memory_usage"] = memory_usage
        if not _is_cpu:
            memory_usage["cuda_graph"] = round(model_runner.cuda_graph_mem_usage, 2)

        speculative_enabled = not self.spec_algorithm.is_none()
        if speculative_enabled and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
2259
2260
2261
2262
2263

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update whitelisted entries of ``global_server_args_dict`` at runtime.

        Rules:
        - Only keys in ``args_allow_update`` may be changed.
        - ``max_micro_batch_size`` must lie in
          ``[1, max_running_requests // pp_size]``.
        - Validation stops at the first rejected key, and then nothing is
          written.
        - On success, the accumulated speculative-acceptance statistics are
          logged (if any) and reset, then all values are applied.

        Returns a ``SetInternalStateReqOutput`` with ``updated`` set to whether
        the update was actually applied, and the current global args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logging.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logging.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        # Report the real outcome: previously this was always True, so callers
        # could not distinguish a rejected update from an applied one.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler method named by ``recv_req.method`` with its parameters.

        Any exception from the lookup or the call is logged and reported in the
        output instead of propagating. A torch.distributed ``barrier()`` is
        entered afterwards in both cases.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        message = "" if not error else str(error)
        return RpcReqOutput(error is None, message)

2317
2318
    def abort_request(self, recv_req: AbortReq):
        """Abort every request matched by ``recv_req``.

        A request matches when ``recv_req.abort_all`` is set, or when its rid
        starts with ``recv_req.rid``. The abort mechanism depends on where the
        request currently lives:

        - waiting queue: popped, and an ``AbortReq`` is sent to the tokenizer;
        - grammar queue: grammar future cancelled, request finished-with-abort;
        - PD-disaggregation queues: the KV sender/receiver's ``abort()`` is
          called when it has one;
        - running/current batch: ``to_abort`` is set on unfinished requests.
        """

        def matches(rid):
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [i for i, req in enumerate(self.waiting_queue) if matches(req.rid)]

        # Sort in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if matches(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled.
        # Debug logs are emitted only for requests that actually match (previously
        # every queued request was logged as aborted).
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if matches(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if matches(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and matches(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2390

2391
2392
2393
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause hook; not implemented here and always raises NotImplementedError.

        NOTE(review): presumably meant to be overridden elsewhere; confirm.
        """
        raise NotImplementedError

2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
2404
2405
2406
2407
2408
2409
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        Delegates to ``self.tp_worker.load_lora_adapter`` and returns its result.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter.

        Delegates to ``self.tp_worker.unload_lora_adapter`` and returns its result.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2410
2411
2412
2413
2414
2415
2416
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` on the scheduler.

        A non-positive value is normalized to None; None is kept as None.
        """
        requested = recv_req.forward_sleep_time
        non_positive = requested is not None and requested <= 0
        self.forward_sleep_time = None if non_positive else requested
        return SlowDownReqOutput()

2417
2418
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Forward a start/stop/dump command to the global expert-distribution recorder.

        Raises ValueError for any other ``ExpertDistributionReq`` value.
        """
        recorder_actions = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for command, method_name in recorder_actions:
            if recv_req == command:
                getattr(get_global_expert_distribution_recorder(), method_name)()
                break
        else:
            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
        return ExpertDistributionReqOutput()
2427

2428
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new Session for ``recv_req.session_id``.

        Fails (returns success=False) if the id is already in use or is None.
        """
        session_id = recv_req.session_id

        # Reject duplicate ids first, then a missing id.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2442
2443
2444
2445
2446
2447
2448
2449
2450

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``; warn if unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
        else:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")

2451
2452
    def get_print_prefix(self):
        """Build a " DPx TPy PPz"-style prefix from the ranks that are in use.

        DP appears when ``attn_dp_rank`` is set, TP when tp_size > 1, and PP
        when pp_size > 1.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2461
2462
    def current_scheduler_metrics_enabled(self):
        """Return whether this scheduler should emit metrics.

        True on attention-TP rank 0; otherwise follows
        ``enable_metrics_for_all_schedulers``.
        """
        if self.attn_tp_rank == 0:
            return True
        return self.enable_metrics_for_all_schedulers
2463

2464
2465
2466
    def maybe_sleep_on_idle(self):
        """Delegate to ``self.idle_sleeper.maybe_sleep()`` when a sleeper is configured."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2467

2468

2469
2470
2471
2472
2473
2474
2475
class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This would lead not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important in cases when multiple GPUs are connected
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    def __init__(self, sockets):
        """Register every socket in ``sockets`` for readability (POLLIN) polling."""
        self.poller = zmq.Poller()
        # Monotonic clock: the empty-cache interval is a duration, and
        # time.time() can jump backwards/forwards on wall-clock adjustments.
        self.last_empty_time = time.monotonic()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a registered socket is readable or 1000 ms elapse.

        Afterwards, if ``global_config.torch_empty_cache_interval`` is positive
        and more than that many seconds have passed since construction or the
        previous empty, release cached CUDA memory via ``torch.cuda.empty_cache()``.
        """
        self.poller.poll(1000)
        interval = global_config.torch_empty_cache_interval
        now = time.monotonic()
        if interval > 0 and now - self.last_empty_time > interval:
            self.last_empty_time = now
            torch.cuda.empty_cache()
2496

2497

2498
2499
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has a string ``rid`` starting with "HEALTH_CHECK".

    Objects without a ``rid`` attribute, or whose ``rid`` is not a string
    (e.g. None), are not health checks. Previously a ``rid`` of None raised
    AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
2500

2501
2502
2503

def is_work_request(recv_req):
    """Return True if ``recv_req`` is a tokenized generate or embedding request."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
2504
2505


2506
2507
2508
2509
2510
def _build_scheduler_prefix(
    server_args: ServerArgs,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
) -> str:
    """Build the " DPx TPy EPz PPw" tag used in the process title and logs.

    Each component is included only when that parallelism is in use:
    ``dp_rank`` is not None, or the matching ``*_size`` is greater than 1.
    """
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.ep_size > 1:
        prefix += f" EP{moe_ep_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"
    return prefix


def _run_scheduler_event_loop(scheduler, server_args: ServerArgs) -> None:
    """Enter the event loop matching the scheduler's disaggregation mode.

    Within each mode, the pipeline-parallel (NULL mode only) or overlap
    variant is chosen when enabled, and the normal loop otherwise.
    """
    disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
    if disaggregation_mode == DisaggregationMode.NULL:
        if server_args.pp_size > 1:
            scheduler.event_loop_pp()
        elif scheduler.enable_overlap:
            scheduler.event_loop_overlap()
        else:
            scheduler.event_loop_normal()
    elif disaggregation_mode == DisaggregationMode.PREFILL:
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap_disagg_prefill()
        else:
            scheduler.event_loop_normal_disagg_prefill()
    elif disaggregation_mode == DisaggregationMode.DECODE:
        if scheduler.enable_overlap:
            scheduler.event_loop_overlap_disagg_decode()
        else:
            scheduler.event_loop_normal_disagg_decode()


def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    """Entry point of a scheduler subprocess.

    Configures the process: title, faulthandler, ``kill_itself_when_parent_died``,
    logging, and optional CPU affinity. It then builds a ``Scheduler`` and sends
    a ``"ready"`` status message carrying ``max_total_num_tokens`` and
    ``max_req_input_len`` through ``pipe_writer``. Finally, it runs the event
    loop selected by the scheduler's disaggregation mode.

    Exceptions raised while building or running the scheduler are logged, and
    the parent process is sent SIGQUIT.
    """
    # [For Router] If env var "SGLANG_DP_RANK" exists, use it as dp_rank.
    # This is resolved before the prefix is built so that the process title
    # and log prefix report the DP rank this scheduler actually runs as.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    prefix = _build_scheduler_prefix(
        server_args, tp_rank, moe_ep_rank, pp_rank, dp_rank
    )

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        _run_scheduler_event_loop(scheduler, server_args)
    except Exception:
        # Named so it does not shadow the stdlib `traceback` module.
        exc_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {exc_traceback}")
        parent_process.send_signal(signal.SIGQUIT)