scheduler.py 116 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
53
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
54
    ReqToMetadataIdxAllocator,
55
    TransferBackend,
56
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
57
)
58
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
59
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
from sglang.srt.layers.moe import initialize_moe_config
68
69
from sglang.srt.managers.io_struct import (
    AbortReq,
70
71
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
72
73
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
74
    CloseSessionReqInput,
75
    DestroyWeightsUpdateGroupReqInput,
76
    ExpertDistributionReq,
77
    ExpertDistributionReqOutput,
78
79
    FlushCacheReqInput,
    FlushCacheReqOutput,
80
    FreezeGCReq,
81
82
    GetInternalStateReq,
    GetInternalStateReqOutput,
83
84
    GetLoadReqInput,
    GetLoadReqOutput,
85
    GetWeightsByNameReqInput,
86
    HealthCheckOutput,
87
88
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
89
    InitWeightsUpdateGroupReqInput,
90
91
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
92
    MultiTokenizerRegisterReq,
93
    MultiTokenizerWrapper,
94
95
    OpenSessionReqInput,
    OpenSessionReqOutput,
96
    ProfileReq,
97
98
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
99
100
    RpcReqInput,
    RpcReqOutput,
101
102
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
103
104
    SetInternalStateReq,
    SetInternalStateReqOutput,
105
106
    SlowDownReqInput,
    SlowDownReqOutput,
107
108
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
109
110
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
111
    UpdateWeightFromDiskReqInput,
112
    UpdateWeightsFromDistributedReqInput,
113
    UpdateWeightsFromTensorReqInput,
114
)
115
from sglang.srt.managers.mm_utils import init_embedding_cache
116
117
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
118
    MultimodalInputs,
119
    Req,
120
    RequestStage,
121
    ScheduleBatch,
122
    global_server_args_dict,
123
)
124
125
126
127
128
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
129
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
130
131
132
133
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
134
135
136
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
137
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
138
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
139
140
141
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
142
from sglang.srt.managers.session_controller import Session
143
from sglang.srt.managers.tp_worker import TpModelWorker
144
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
145
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
tarinkk's avatar
tarinkk committed
146
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
147
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
148
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
149
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Lianmin Zheng's avatar
Lianmin Zheng committed
150
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
151
from sglang.srt.parser.reasoning_parser import ReasoningParser
152
from sglang.srt.server_args import PortArgs, ServerArgs
153
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
154
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
155
156
157
158
159
160
161
162
163
from sglang.srt.tracing.trace import (
    process_tracing_init,
    trace_event,
    trace_set_proc_propagate_context,
    trace_set_thread_info,
    trace_slice,
    trace_slice_end,
    trace_slice_start,
)
164
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
165
from sglang.srt.utils import (
166
    DynamicGradMode,
167
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
168
    configure_gc_logger,
169
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
170
    disable_request_logging,
171
    freeze_gc,
172
    get_available_gpu_memory,
173
    get_bool_env_var,
174
    get_int_env_var,
175
    get_zmq_socket,
176
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
177
    kill_itself_when_parent_died,
178
    numa_bind_to_node,
179
    point_to_point_pyobj,
180
    pyspy_dump_schedulers,
181
182
    require_mlp_sync,
    require_mlp_tp_gather,
183
    set_gpu_proc_affinity,
184
185
186
    set_random_seed,
    suppress_other_loggers,
)
187
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
# Module-level logger and environment-driven settings.
logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Read from SGLANG_GRAMMAR_TIMEOUT (default 300); a non-numeric value makes
# float() raise ValueError when this module is imported.
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))

# Evaluated once, at import time.
_is_cpu = is_cpu()


@dataclass
class GenerationBatchResult:
    """Result container for one forward pass of a generation batch.

    The producer/consumer of this object is not visible here; the field notes
    below are inferred from names and should be confirmed against the code
    that builds it.
    """

    # Logits processor output; presumably None on ranks that do not compute logits.
    logits_output: Optional[LogitsProcessorOutput]
    # NOTE(review): name suggests a pipeline-parallel proxy-tensor payload and the
    # file imports PPProxyTensors; the annotation may be looser than the actual
    # type — confirm.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids, if any.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start positions within the extend segment.
    extend_logprob_start_len_per_req: List[int]
    # Batch id (presumably matches the originating batch).
    bid: int
    # Whether the forward pass could run with a CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result container for one forward pass of an embedding batch."""

    # Embeddings produced by the forward pass.
    embeddings: torch.Tensor
    # Batch id (presumably matches the originating batch).
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
217
218
219
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
220
221
222
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
223
224
225
226
227
228
229
230
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
        dp_balance_meta: Optional[DPBalanceMeta] = None,
    ):
        """Build the scheduler: IPC sockets, model workers, caches and handlers.

        Args:
            server_args: Server configuration; most scheduler settings are
                copied from it.
            port_args: IPC endpoint names (scheduler input, RPC, tokenizer,
                detokenizer, metrics) and the NCCL port used by the workers.
            gpu_id: GPU id passed to the TP worker and the draft worker.
            tp_rank: Tensor-parallel rank of this scheduler.
            moe_ep_rank: MoE expert-parallel rank.
            pp_rank: Pipeline-parallel rank.
            dp_rank: Data-parallel rank, or ``None``.
            dp_balance_meta: Optional metadata forwarded to ``init_dp_balance``.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_priority_scheduling = server_args.enable_priority_scheduling
        self.schedule_low_priority_values_first = (
            server_args.schedule_low_priority_values_first
        )
        self.priority_scheduling_preemption_threshold = (
            server_args.priority_scheduling_preemption_threshold
        )
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        # Only TP rank 0 publishes KV cache events (value is truthy/falsy, not
        # necessarily a bool).
        self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication. Only the first PP stage's attention
        # TP rank 0 owns real sockets; every other rank gets no-op senders.
        context = zmq.Context(2)
        self.idle_sleeper = None
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled.
        # Only the first token id of the encoded think-end token is kept.
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        TpWorkerClass = TpModelWorkerClient if self.enable_overlap else TpModelWorker
        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding. Each worker module is
        # imported inside its branch, so only the selected one is loaded.
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            draft_worker_cls = EAGLEWorker
        elif self.spec_algorithm.is_standalone():
            from sglang.srt.speculative.standalone_worker import StandaloneWorker

            draft_worker_cls = StandaloneWorker
        elif self.spec_algorithm.is_lookahead():
            from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker

            draft_worker_cls = LOOKAHEADWorker
        else:
            draft_worker_cls = None

        if draft_worker_cls is None:
            self.draft_worker = None
        else:
            # All supported draft workers are constructed with the same kwargs.
            self.draft_worker = draft_worker_cls(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()

        # Merge the worker's global args BEFORE filling in defaults: dict.update
        # would otherwise overwrite the computed max_micro_batch_size with a None
        # carried by the worker dict (no-op if both names refer to the same dict).
        global_server_args_dict.update(worker_global_server_args_dict)
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill. A non-positive size (e.g. -1) disables it; a
        # None size is treated as disabled too instead of raising TypeError
        # on the comparison.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
            self.enable_priority_scheduling,
            self.schedule_low_priority_values_first,
        )
        # Enable preemption for priority scheduling.
        self.try_preemption = self.enable_priority_scheduling

        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # The new-token ratio starts at init_new_token_ratio and decays linearly
        # toward min_new_token_ratio over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. parent_process is resolved first so the attribute
        # already exists if the watchdog thread (defined elsewhere) reads it —
        # NOTE(review): confirm watchdog_thread's use of parent_process.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profiler()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_dp_balance(dp_balance_meta)

        if self.enable_kv_cache_events:
            self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init prefill kv split size when deterministic inference is enabled with various attention backends
        self.init_deterministic_inference_config()

        # Init request dispatcher: (request type, handler) pairs.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
                ),
                (
                    SendWeightsToRemoteInstanceReqInput,
                    self.send_weights_to_remote_instance,
                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (FreezeGCReq, self.handle_freeze_gc),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
                (GetLoadReqInput, self.get_load),
            ]
        )

    def init_deterministic_inference_config(self):
        """Set ``self.truncation_align_size`` for deterministic inference.

        - Deterministic inference disabled: ``None``.
        - ``flashinfer`` backend: int from
          ``SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE`` (default 4096).
        - ``triton`` backend: int from
          ``SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE`` (default 4096).
        - Any other backend: ``None``.
        """
        if not self.server_args.enable_deterministic_inference:
            self.truncation_align_size = None
            return

        # attention backend -> (env var override, default size)
        backend_sizes = {
            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
        }
        env_var, default_size = backend_sizes.get(
            self.server_args.attention_backend, (None, None)
        )
        self.truncation_align_size = (
            get_int_env_var(env_var, default_size) if env_var else None
        )

    def init_tokenizer(self):
        """Load the tokenizer and, for multimodal models, the processor.

        Sets ``self.is_generation`` from the model config, then:

        - if ``server_args.skip_tokenizer_init``: ``self.tokenizer`` and
          ``self.processor`` are both ``None``;
        - if the model is multimodal: a processor is loaded and the tokenizer
          is taken from it;
        - otherwise: a plain tokenizer is loaded and ``self.processor`` is set
          to ``None`` so the attribute exists on every path.
        """
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only model: no processor, but keep the attribute defined
            # (the other two branches always set it).
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
Hanming Lu's avatar
Hanming Lu committed
670
            if self.is_hybrid:
tarinkk's avatar
tarinkk committed
671
672
673
674
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
675
676
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
677
                page_size=self.page_size,
678
679
            )
        else:
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
                # lazy import to avoid JIT overhead
                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

                self.tree_cache = RadixCacheCpp(
                    disable=False,
                    use_hicache=self.enable_hierarchical_cache,
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool_allocator,
                    tp_cache_group=self.tp_cpu_group,
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )
            elif self.enable_hierarchical_cache:
697
698
699
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
700
701
702
703
704
                    tp_cache_group=(
                        self.attn_tp_cpu_group
                        if self.server_args.enable_dp_attention
                        else self.tp_cpu_group
                    ),
705
                    page_size=self.page_size,
706
                    eviction_policy=server_args.radix_eviction_policy,
707
                    hicache_ratio=server_args.hicache_ratio,
Zhiqiang Xie's avatar
Zhiqiang Xie committed
708
709
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
710
                    hicache_io_backend=server_args.hicache_io_backend,
711
                    hicache_mem_layout=server_args.hicache_mem_layout,
712
                    enable_metrics=self.enable_metrics,
713
                    hicache_storage_backend=server_args.hicache_storage_backend,
pansicheng's avatar
pansicheng committed
714
                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
715
716
                    model_name=server_args.served_model_name,
                    storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
717
                )
718
719
720
                self.tp_worker.register_hicache_layer_transfer_counter(
                    self.tree_cache.cache_controller.layer_done_counter
                )
Hanming Lu's avatar
Hanming Lu committed
721
722
723
724
725
726
727
728
729
730
731
            elif self.is_hybrid:
                assert (
                    self.server_args.disaggregation_mode == "null"
                ), "Hybrid mode does not support disaggregation yet"
                self.tree_cache = SWARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    sliding_window_size=self.sliding_window_size,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                )
732
733
734
735
736
737
738
739
740
741
742
743
744
745
            elif server_args.enable_lmcache:
                from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                    LMCRadixCache,
                )

                self.tree_cache = LMCRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    model_config=self.model_config,
                    tp_size=self.tp_size,
                    rank=self.tp_rank,
                    tp_group=self.tp_group,
746
                    eviction_policy=server_args.radix_eviction_policy,
747
                )
748
749
750
751
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
Lianmin Zheng's avatar
Lianmin Zheng committed
752
                    page_size=self.page_size,
753
                    disable=server_args.disable_radix_cache,
754
                    enable_kv_cache_events=self.enable_kv_cache_events,
755
                    eviction_policy=server_args.radix_eviction_policy,
756
757
758
759
760
761
762
763
                )

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
764
765
                    (server_args.speculative_eagle_topk or 1)
                    * (server_args.speculative_num_steps or 1)
766
767
                )
            )
768
        )
769

770
771
772
        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

Byron Hsu's avatar
Byron Hsu committed
773
    def init_disaggregation(self):
        """Set up the state used by prefill/decode disaggregation.

        ``self.transfer_backend`` is always built from
        ``server_args.disaggregation_transfer_backend``. What follows depends on
        ``self.disaggregation_mode``:

        * DECODE: metadata index allocator + buffers sized at twice the
          req-to-token pool size, a transfer queue for requests polling their
          KV cache, and a pre-allocation queue for requests waiting for slots.
        * PREFILL: metadata index allocator + buffers sized at twice
          ``max_running_requests``, a bootstrap queue, and an empty list of
          requests whose KV cache is still being sent.
        * Any other mode: nothing else is created.
        """
        server_args = self.server_args
        self.transfer_backend = TransferBackend(
            server_args.disaggregation_transfer_backend
        )

        def _allocate_metadata(num_slots):
            # Index allocator and metadata buffers share one capacity; callers
            # pass twice their natural bound as headroom.
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                num_slots
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                num_slots,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

        def _draft_kv_pool():
            # No draft KV pool without a draft worker or under lookahead.
            if self.draft_worker is None or self.spec_algorithm.is_lookahead():
                return None
            return self.draft_worker.model_runner.token_to_kv_pool

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            _allocate_metadata(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=_draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )
        elif mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            _allocate_metadata(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=_draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=server_args.disaggregation_decode_tp,
                decode_dp_size=server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    def init_moe_config(self):
        """Initialize the MoE config from server args.

        Only done when the model's HF config exposes ``num_experts_per_tok``;
        otherwise this is a no-op.
        """
        hf_config = self.model_config.hf_config
        if not hasattr(hf_config, "num_experts_per_tok"):
            return
        initialize_moe_config(self.server_args)

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration receives and dispatches new requests, then picks the
        next batch. A batch is traced, run, and its result processed
        synchronously. When there is no batch, idle-time self-checks run
        instead. The batch (or None) is then remembered as ``last_batch``.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()
            else:
                for req in batch.reqs:
                    trace_event("schedule", req.rid)
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        A batch launched in one iteration is pushed onto ``self.result_queue``
        as a (copy, result) pair. Its result is processed in the next iteration,
        after the following batch has been launched. Result processing is given
        the ``launch_done`` event of the batch launched in the current iteration
        (or None if nothing was launched).
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                for req in batch.reqs:
                    trace_event("schedule", req.rid)

                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    dummy_first = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_first, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
                self.process_batch_result(
                    prev_batch, prev_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop cycles over ``self.pp_size`` microbatches. For microbatch
        ``mb_id`` each rank receives and dispatches new requests, then
        schedules and runs the next batch. If microbatch
        ``(mb_id + 1) % pp_size`` holds a batch, the rank then receives that
        microbatch's outputs and post-processes it.

        The last rank sends the sampled token ids onward, plus logprob data
        when the batch requests it. Other ranks forward the previous round's
        outputs, the received requests and the hidden-state proxy tensors to
        the next stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # CUDA-graph flag of this rank's own forward, per microbatch. A
        # microbatch's outputs are post-processed in a later step, so the flag
        # must be looked up by microbatch rather than taken from whatever
        # `result` was produced most recently.
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                mb_can_run_cuda_graph[mb_id] = False
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs: Optional[PPProxyTensors] = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Rebuild the logits output from the "logits_output.*" keys
                    # packed by the last rank, if any were sent.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        # Flag of the microbatch being post-processed, not of
                        # the one just launched (which may differ or be stale).
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, do self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()

    def recv_requests(self) -> List[Req]:
        """Collect this round's incoming requests on every rank.

        Returns ``[]`` immediately when ``self.recv_skipper`` declines this round.

        Otherwise the attention-TP leader (``attn_tp_rank == 0``) gathers the
        requests. On pipeline rank 0 it drains the tokenizer socket and then
        the RPC socket with non-blocking reads. On later pipeline ranks it
        receives the list from the previous stage via ``point_to_point_pyobj``.
        Other ranks start from ``None``.

        After the optional ``input_blocker`` pass, the list is broadcast:

        * With DP attention, work requests (tokenized generate/embedding inputs,
          single or batched) are broadcast over the attention-TP group. All
          remaining control requests go over the full TP group. The result is
          work requests followed by control requests.
        * Otherwise the whole list goes over the TP group.
        """
        if self.recv_skipper is not None:
            last_forward_mode = (
                None if self.last_batch is None else self.last_batch.forward_mode
            )
            if not self.recv_skipper.handle(last_forward_mode):
                return []

        def _drain(socket, sink):
            # Append every message currently queued on `socket` without blocking.
            while True:
                try:
                    message = socket.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    return
                sink.append(message)

        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                _drain(self.recv_from_tokenizer, recv_reqs)
                _drain(self.recv_from_rpc, recv_reqs)
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            work_reqs = None
            control_reqs = None
            if self.attn_tp_rank == 0:
                work_types = (
                    TokenizedGenerateReqInput,
                    TokenizedEmbeddingReqInput,
                    BatchTokenizedGenerateReqInput,
                    BatchTokenizedEmbeddingReqInput,
                )
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    if isinstance(req, work_types):
                        work_reqs.append(req)
                    else:
                        control_reqs.append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        traced_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
        for req in recv_reqs:
            if isinstance(req, traced_types):
                trace_set_proc_propagate_context(req.rid, req.trace_context)
                trace_slice_start("", req.rid, anonymous=True)

        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route the reply it produces.

        Health-check generate requests are counted in
        ``self.return_health_check_ct`` and dropped while the scheduler is busy.
        Busy means a chunked request is pending, the running batch is
        non-empty, or offload tags are outstanding.

        ``MultiTokenizerWrapper`` requests are unwrapped and dispatched. Their
        reply is re-wrapped with the same ``worker_id`` and sent to the
        tokenizer.

        Any other non-None reply goes back over the RPC socket when it is an
        ``RpcReqOutput`` (and that socket exists), and to the tokenizer
        otherwise.
        """
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req):
                busy = (
                    self.chunked_req is not None
                    or not self.running_batch.is_empty()
                    or len(self.offload_tags) > 0
                )
                if busy:
                    self.return_health_check_ct += 1
                    continue

            if isinstance(recv_req, MultiTokenizerWrapper):
                # Handle the inner request, then re-wrap the reply so it is
                # routed back to the originating tokenizer worker.
                worker_id, inner_req = recv_req.worker_id, recv_req.obj
                reply = self._request_dispatcher(inner_req)
                if reply is not None:
                    self.send_to_tokenizer.send_pyobj(
                        MultiTokenizerWrapper(worker_id, reply)
                    )
                continue

            reply = self._request_dispatcher(recv_req)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)

    def init_req_max_new_tokens(self, req):
        """Cap ``req.sampling_params.max_new_tokens`` by the remaining length budget.

        An unset (None) value is treated as ``1 << 30``. The cap is
        ``self.max_req_len - len(req.origin_input_ids) - 1``. The result may be
        zero or negative when the prompt already fills that budget.
        """
        params = req.sampling_params
        requested = params.max_new_tokens
        if requested is None:
            requested = 1 << 30
        length_budget = self.max_req_len - len(req.origin_input_ids) - 1
        params.max_new_tokens = min(requested, length_budget)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
1223
        self.maybe_update_dp_balance_data(recv_req)
1224

1225
        # Create a new request
1226
1227
1228
1229
1230
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
Rin Intachuen's avatar
Rin Intachuen committed
1231
1232
1233
1234
1235
1236
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

1237
1238
1239
1240
            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

1241
1242
1243
1244
1245
            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
Lianmin Zheng's avatar
Lianmin Zheng committed
1246
1247
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
1248
                token_ids_logprob=recv_req.token_ids_logprob,
Lianmin Zheng's avatar
Lianmin Zheng committed
1249
                stream=recv_req.stream,
1250
                lora_id=recv_req.lora_id,
Rin Intachuen's avatar
Rin Intachuen committed
1251
                input_embeds=recv_req.input_embeds,
Lianmin Zheng's avatar
Lianmin Zheng committed
1252
                custom_logit_processor=recv_req.custom_logit_processor,
1253
                return_hidden_states=recv_req.return_hidden_states,
1254
                eos_token_ids=self.model_config.hf_eos_token_id,
1255
                bootstrap_host=recv_req.bootstrap_host,
1256
                bootstrap_port=recv_req.bootstrap_port,
1257
                bootstrap_room=recv_req.bootstrap_room,
1258
                data_parallel_rank=recv_req.data_parallel_rank,
1259
                vocab_size=self.model_config.vocab_size,
1260
                priority=recv_req.priority,
1261
1262
1263
                metrics_collector=(
                    self.metrics_collector if self.enable_metrics else None
                ),
1264
1265
            )
            req.tokenizer = self.tokenizer
Lianmin Zheng's avatar
Lianmin Zheng committed
1266

1267
1268
1269
            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
1270
                    error_msg = (
1271
1272
1273
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
1274
                    logger.error(error_msg)
1275
                    prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
1276
1277
1278
                    self.stream_output([req], req.return_logprob)
                    return

1279
1280
1281
1282
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
1283
                req.set_finish_with_abort(
1284
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
1285
                )
1286
                self.init_req_max_new_tokens(req)
1287
                self._add_request_to_queue(req)
1288
1289
                return
        else:
1290
1291
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
1292
            req = session.create_req(recv_req, self.tokenizer)
1293
            if isinstance(req.finished_reason, FINISH_ABORT):
1294
                self.init_req_max_new_tokens(req)
1295
                self._add_request_to_queue(req)
1296
                return
1297

1298
        # Handle multimodal inputs
Mick's avatar
Mick committed
1299
1300
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
1301
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
1302
            req.origin_input_ids = self.pad_input_ids_func(
1303
                req.origin_input_ids, image_inputs
1304
            )
1305
            req.extend_image_inputs(image_inputs)
1306

1307
            if len(req.origin_input_ids) >= self.max_req_input_len:
1308
1309
1310
1311
1312
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
1313
                )
1314
                self.init_req_max_new_tokens(req)
1315
                self._add_request_to_queue(req)
1316
1317
                return

1318
1319
1320
        # initialize before returning
        self.init_req_max_new_tokens(req)

1321
        # Validate prompt length
1322
1323
1324
1325
1326
1327
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
1328
            req.set_finish_with_abort(error_msg)
1329
            self._add_request_to_queue(req)
1330
            return
1331

1332
        # Copy more attributes
1333
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
1334
            # By default, only return the logprobs for output tokens
1335
1336
1337
1338
1339
1340
1341
            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
            # to skip input logprob computation entirely
            if req.is_prefill_only:
                req.logprob_start_len = len(req.origin_input_ids)
            else:
                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
                req.logprob_start_len = len(req.origin_input_ids) - 1
1342
1343
1344
        else:
            req.logprob_start_len = recv_req.logprob_start_len

1345
1346
1347
        if not req.is_prefill_only and req.logprob_start_len >= len(
            req.origin_input_ids
        ):
1348
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
1349
            req.logprob_start_len = len(req.origin_input_ids) - 1
1350
            req.set_finish_with_abort(error_msg)
1351
1352
1353
            self._add_request_to_queue(req)
            return

1354
1355
1356
1357
1358
        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
1359
            or req.sampling_params.ebnf is not None
1360
            or req.sampling_params.structural_tag is not None
1361
1362
1363
1364
1365
1366
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
1367
1368
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
1369
1370
            elif req.sampling_params.structural_tag:
                key = ("structural_tag", req.sampling_params.structural_tag)
1371

1372
1373
1374
1375
1376
            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
1377
                add_to_grammar_queue = True
1378
1379
1380
1381
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)
1382
1383

        if add_to_grammar_queue:
1384
            req.queue_time_start = time.perf_counter()
1385
1386
            self.grammar_queue.append(req)
        else:
1387
1388
            self._add_request_to_queue(req)

1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
    def handle_batch_generate_request(
        self,
        recv_req: BatchTokenizedGenerateReqInput,
    ):
        """Fan a batched tokenized generate request out to the per-request handler.

        Every element of ``recv_req`` is passed, in iteration order, to
        ``handle_generate_request``.
        """
        batch_len = len(recv_req)
        logger.debug(f"Processing batch generate request with {batch_len} requests")
        for single_req in recv_req:
            self.handle_generate_request(single_req)

1400
    def _add_request_to_queue(self, req: Req):
        """Route one request to the queue that matches this server's role.

        ``req.queue_time_start`` is stamped first, then:
        - PREFILL disaggregation: start the KV-cache prefetch and hand the
          request to the prefill bootstrap queue;
        - DECODE disaggregation: hand the request to the decode prealloc queue;
        - otherwise: set/validate its priority, stop if the queue-limit check
          aborted it, prefetch, and append it to ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            self._prefetch_kvcache(req)
            kv_heads = self.model_config.num_key_value_heads
            self.disagg_prefill_bootstrap_queue.add(req, kv_heads)
            return

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        self._set_or_validate_priority(req)
        if self._abort_on_queued_limit(req):
            # The incoming request itself was rejected; do not enqueue it.
            return
        self._prefetch_kvcache(req)
        self.waiting_queue.append(req)
        trace_slice_end("process req", req.rid, auto_next_anon=True)
Byron Hsu's avatar
Byron Hsu committed
1416

1417
1418
1419
    def _prefetch_kvcache(self, req: Req):
        """Ask the tree cache to prefetch the unmatched tail of ``req`` from storage.

        No-op unless hicache storage is enabled. The request's next-round input
        is (re)initialized against the tree cache first; the prefetch is only
        issued when the request's last matched node is backed up.
        """
        if not self.enable_hicache_storage:
            return

        cache = self.tree_cache
        req.init_next_round_input(cache)
        # Only prefetch when the last node is backed up; otherwise the allocated
        # GPU memory must stay locked for integrity.
        if not req.last_node.backuped:
            return

        host_node = req.last_host_node
        hash_of_last = host_node.get_last_hash_value()
        already_matched = len(req.prefix_indices) + req.host_hit_length
        cache.prefetch_from_storage(
            req.rid, host_node, req.fill_ids[already_matched:], hash_of_last
        )

1430
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Put several requests onto the queue that matches this server's role.

        Args:
            reqs: requests to enqueue (this file uses it for requests preempted
                while building a prefill batch).
            is_retracted: forwarded to the decode prealloc queue in DECODE mode;
                ignored otherwise.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return
        if mode == DisaggregationMode.DECODE:
            # Decode servers park requests in the pending prealloc queue.
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return

        for pending in reqs:
            self._set_or_validate_priority(pending)
            if self._abort_on_queued_limit(pending):
                continue
            self.waiting_queue.append(pending)

    def _set_or_validate_priority(self, req: Req):
        """Reconcile ``req.priority`` with the server's priority-scheduling mode.

        With priority scheduling on, a missing priority is replaced by the
        least-preferred value for the configured ordering (``sys.maxsize`` when
        low values go first, ``-sys.maxsize - 1`` otherwise). With it off, a
        request that carries a priority triggers an ``AbortReq`` sent to the
        tokenizer; ``req`` itself is not modified here.
        """
        # NOTE(review): callers keep enqueueing the request after the abort
        # notice is sent in the disabled case — confirm that is intended.
        if self.enable_priority_scheduling:
            if req.priority is None:
                req.priority = (
                    sys.maxsize
                    if self.schedule_low_priority_values_first
                    else -sys.maxsize - 1
                )
            return

        if req.priority is None:
            return

        self.send_to_tokenizer.send_pyobj(
            AbortReq(
                req.rid,
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
            )
        )

    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
        """Abort an incoming or existing request if the waiting queue is full.

        When admitting ``recv_req`` would push ``waiting_queue`` past
        ``max_queued_requests``, exactly one request is aborted by sending an
        ``AbortReq`` (status 503) to the tokenizer:

        - by default, the incoming ``recv_req``;
        - with priority scheduling, the least-preferred queued request instead,
          when ``recv_req`` strictly outranks it. That request is also popped
          from ``waiting_queue``.

        Returns:
            True if ``recv_req`` is the aborted request (the caller must not
            enqueue it), False otherwise.
        """
        if (
            self.max_queued_requests is None
            or len(self.waiting_queue) + 1 <= self.max_queued_requests
        ):
            return False

        # Reject the incoming request by default.
        req_to_abort = recv_req
        message = "The request queue is full."
        # An empty queue (reachable when max_queued_requests <= 0) has no
        # eviction candidate, and max() over it would raise ValueError.
        if self.enable_priority_scheduling and self.waiting_queue:
            # With priority scheduling, consider aborting an existing request based on the priority.
            # direction = 1  => smaller number = higher priority; -1 => larger number = higher priority.
            # max() over (direction * priority, queue_time_start) picks the least-preferred request.
            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
            direction = 1 if self.schedule_low_priority_values_first else -1
            idx, candidate_req = max(
                enumerate(self.waiting_queue),
                key=lambda item: (
                    direction * item[1].priority,
                    item[1].queue_time_start,
                ),
            )
            if direction * recv_req.priority < direction * candidate_req.priority:
                self.waiting_queue.pop(idx)
                req_to_abort = candidate_req
                message = "The request is aborted by a higher priority request."

        self.send_to_tokenizer.send_pyobj(
            AbortReq(
                req_to_abort.rid,
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": message,
                },
            )
        )
        return req_to_abort.rid == recv_req.rid
1503
1504
1505

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Multimodal inputs are expanded into dummy tokens first. A prompt that is
        too long, either after multimodal expansion or per
        ``validate_input_length``, is still enqueued, but marked aborted with the
        corresponding error so it is returned to the client instead of run.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
            priority=recv_req.priority,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Mark the request aborted (as the generate path does); without this
            # the over-long prompt would be scheduled like a normal request.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1550

1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
    def handle_batch_embedding_request(
        self,
        recv_req: BatchTokenizedEmbeddingReqInput,
    ):
        """Fan a batched tokenized embedding request out to the per-request handler.

        Every element of ``recv_req`` is passed, in iteration order, to
        ``handle_embedding_request``.
        """
        batch_len = len(recv_req)
        logger.debug(
            f"Processing batch embedding request with {batch_len} requests"
        )
        for single_req in recv_req:
            self.handle_embedding_request(single_req)

1564
1565
1566
1567
1568
    def self_check_during_idle(self):
        """Run the idle-time housekeeping pass.

        Checks memory pools and the tree cache (in that order), resets
        ``new_token_ratio`` to ``init_new_token_ratio``, then calls
        ``maybe_sleep_on_idle``.
        """
        for run_check in (self.check_memory, self.check_tree_cache):
            run_check()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
1569

Lianmin Zheng's avatar
Lianmin Zheng committed
1570
    def check_memory(self):
        """Verify that the KV token pool and the request pool are fully reclaimed.

        Raises:
            ValueError: if the token-to-KV pool allocator or ``req_to_token_pool``
                does not account for all of its slots (a memory leak).

        In addition, when metrics are enabled and the last metrics log is more
        than 30 seconds old, refreshes ``self.stats`` and logs it. Finally calls
        ``_publish_kv_events``.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # Free + evictable must cover every token except those the tree cache
            # reports as protected; any shortfall is unaccounted-for memory.
            memory_leak = (available_size + evictable_size) != (
                self.max_total_num_tokens - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # On decode servers the pool capacity also includes pre_alloc_size slots.
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        if len(self.req_to_token_pool.free_slots) != req_total_size:
            # Report the same total the check compared against.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1662

Hanming Lu's avatar
Hanming Lu committed
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
    def check_tree_cache(self):
        """Run the tree cache's sanity check when ``is_hybrid`` and it is an ``SWARadixCache``."""
        if not self.is_hybrid:
            return
        if isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
        """Summarize KV token-pool occupancy.

        Returns:
            ``(num_used, token_usage, available_size, evictable_size)``, where
            ``num_used`` is ``max_total_num_tokens`` minus free and evictable
            tokens, and ``token_usage`` is that count divided by
            ``max_total_num_tokens``.
        """
        free = self.token_to_kv_pool_allocator.available_size()
        evictable = self.tree_cache.evictable_size()
        capacity = self.max_total_num_tokens
        used = capacity - (free + evictable)
        return used, used / capacity, free, evictable

    def _get_swa_token_info(self):
        """Summarize occupancy of the full and SWA KV pools.

        Returns:
            An 8-tuple ``(full_num_used, swa_num_used, full_token_usage,
            swa_token_usage, full_available_size, full_evictable_size,
            swa_available_size, swa_evictable_size)``. For each pool, "used" is
            its per-layer token capacity minus free and evictable tokens, and
            "usage" is that count divided by the capacity.
        """
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache
        full_free = allocator.full_available_size()
        full_evictable = cache.full_evictable_size()
        swa_free = allocator.swa_available_size()
        swa_evictable = cache.swa_evictable_size()

        full_cap = self.full_tokens_per_layer
        swa_cap = self.swa_tokens_per_layer
        full_used = full_cap - (full_free + full_evictable)
        swa_used = swa_cap - (swa_free + swa_evictable)
        return (
            full_used,
            swa_used,
            full_used / full_cap,
            swa_used / swa_cap,
            full_free,
            full_evictable,
            swa_free,
            swa_evictable,
        )

1698
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        1. If a chunked request is in flight, cache its unfinished prefix and free
           its ``req_to_token_pool`` slot, then fold the previous extend batch
           (minus chunked requests) into ``running_batch``.
        2. Try to build a new prefill batch; prefill takes precedence.
        3. Otherwise update the running decode batch and return it, or ``None``
           if it is empty.
        4. When MLP sync is required, pass the result through
           ``prepare_mlp_sync_batch``.
        """
        excluded_chunks = set()
        chunked = self.chunked_req
        if chunked:
            # Pull the chunked request out so only finished requests are merged
            # into running_batch. It keeps its rid but will get a new
            # req_pool_idx.
            excluded_chunks.add(chunked)
            self.tree_cache.cache_unfinished_req(chunked, chunked=True)
            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
                self.req_to_token_pool.free(
                    chunked.req_pool_idx, free_mamba_cache=False
                )
            else:
                self.req_to_token_pool.free(chunked.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the current
                # microbatch may still track an outdated chunked_req; drop it too.
                excluded_chunks.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded_chunks))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Prefill-only batches skip decoding, so they are never merged.
            if not (prev.is_empty() or prev.is_prefill_only):
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        needs_mlp_sync = require_mlp_sync(self.server_args)
        if needs_mlp_sync and not self.spec_algorithm.is_none():
            # In speculative decoding, prefill and decode batches cannot share a
            # DP attention group: prepare idle batches up front so decode
            # preparation is skipped when the group has a prefill batch.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            needs_mlp_sync = new_batch is None

        if new_batch is not None:
            # Prefill first whenever possible.
            ret = new_batch
        elif self.running_batch.is_empty():
            ret = None
        else:
            # Decode.
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = None if self.running_batch.is_empty() else self.running_batch

        if needs_mlp_sync:
            self.maybe_handle_dp_balance_data()
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
1763

1764
1765
1766
1767
1768
1769
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted alongside ``running_bs``.

        The budget is ``max_micro_batch_size - running_bs``; with pipeline
        parallelism (``pp_size > 1``) it is additionally capped by
        ``req_to_token_pool.available_size()``.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1770
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch, or return ``None``.

        Admits the in-flight chunked request (if any) and then requests from
        ``waiting_queue`` (after ``self.policy.calc_priority`` has, presumably,
        ordered it) through a ``PrefillAdder``. Admission stops when the adder
        runs out of budget, the running batch is marked full (and preemption
        does not help), or a LoRA adapter combination cannot be served.
        Admitted requests are removed from ``waiting_queue``. Requests the adder
        preempted are put back on the queue.

        Returns:
            A ``ScheduleBatch`` prepared for extend, possibly mixed with the
            running decode batch when mixed chunked prefill applies, or ``None``
            when nothing can be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if self.try_preemption:
            # Reset batch_is_full to try preemption with a prefill adder.
            self.running_batch.batch_is_full = False

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or not self.waiting_queue
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if (
            self.get_num_allocatable_reqs(running_bs) <= 0
            and not self.chunked_req
            and not self.try_preemption
        ):
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
            self.priority_scheduling_preemption_threshold,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {r.lora_id for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | {r.lora_id for r in adder.can_run_list}
                | {req.lora_id}
            ):
                self.running_batch.batch_is_full = True
                break

            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc and transfer queues also hold
                # req_to_token_pool slots, so check against the pool's actual
                # available size as well.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True

            if self.running_batch.batch_is_full:
                if not self.try_preemption:
                    break
                if not adder.preempt_to_schedule(req, self.server_args):
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(
                req,
                has_chunked_req=(self.chunked_req is not None),
                truncation_align_size=self.truncation_align_size,
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            # A preemption may have succeeded even though no request was then
            # admitted (e.g. a prefetch still pending, or add_one_req failing).
            # The non-empty path below re-queues preempted requests; do the same
            # here so they are not dropped.
            if adder.preempt_list:
                self._extend_requests_to_queue(adder.preempt_list)
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()
                req.add_latency(RequestStage.PREFILL_WAITING)

        # Build the membership set once instead of once per waiting request.
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]
        if adder.preempt_list:
            self._extend_requests_to_queue(adder.preempt_list)

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1942
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
1943
        """Update the current running decoding batch."""
Lianmin Zheng's avatar
Lianmin Zheng committed
1944
        initial_bs = batch.batch_size()
Lianmin Zheng's avatar
Lianmin Zheng committed
1945

1946
1947
        batch.filter_batch()
        if batch.is_empty():
Lianmin Zheng's avatar
Lianmin Zheng committed
1948
1949
            batch.batch_is_full = False
            return batch
1950

Lianmin Zheng's avatar
Lianmin Zheng committed
1951
        # Check if decode out of memory
1952
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
1953
            TEST_RETRACT and batch.batch_size() > 10
1954
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
1955
1956
            old_ratio = self.new_token_ratio

1957
            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1958
            num_retracted_reqs = len(retracted_reqs)
Lianmin Zheng's avatar
Lianmin Zheng committed
1959
            self.new_token_ratio = new_token_ratio
1960

Lianmin Zheng's avatar
Lianmin Zheng committed
1961
            logger.info(
1962
                "KV cache pool is full. Retract requests. "
1963
                f"#retracted_reqs: {num_retracted_reqs}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
1964
1965
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
1966

1967
            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
1968
            self.total_retracted_reqs += num_retracted_reqs
Lianmin Zheng's avatar
Lianmin Zheng committed
1969
1970
        else:
            self.new_token_ratio = max(
1971
                self.new_token_ratio - self.new_token_ratio_decay,
Lianmin Zheng's avatar
Lianmin Zheng committed
1972
1973
1974
                self.min_new_token_ratio,
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
1975
        if batch.batch_size() < initial_bs:
Lianmin Zheng's avatar
Lianmin Zheng committed
1976
            batch.batch_is_full = False
Lianmin Zheng's avatar
Lianmin Zheng committed
1977
1978

        # Update batch tensors
1979
        batch.prepare_for_decode()
Lianmin Zheng's avatar
Lianmin Zheng committed
1980
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1981

1982
1983
1984
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
1985
        """Run a batch."""
Lianmin Zheng's avatar
Lianmin Zheng committed
1986
1987
        self.forward_ct += 1

1988
1989
        # Whether to run the profiler
        self._profile_batch_predicate(batch)
1990
1991
1992
1993
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

1994
        # Run forward
1995
        if self.is_generation:
1996
1997
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
1998

1999
                if self.pp_group.is_last_rank:
2000
                    logits_output, next_token_ids, can_run_cuda_graph = (
2001
2002
2003
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
2004
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
2005
2006
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
2007
                bid = model_worker_batch.bid
Lianmin Zheng's avatar
Lianmin Zheng committed
2008
            else:
2009
2010
2011
                (
                    logits_output,
                    next_token_ids,
2012
                    bid,
2013
                    num_accepted_tokens,
2014
                    can_run_cuda_graph,
2015
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
2016
2017
2018
                bs = batch.batch_size()
                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
                self.spec_num_total_forward_ct += bs
2019
                self.num_generated_tokens += num_accepted_tokens
2020
2021
2022

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids
2023

2024
2025
2026
            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
2027
            if batch.return_logprob or self.spec_algorithm.is_eagle():
2028
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
2029
2030
2031
            else:
                extend_input_len_per_req = None
            if batch.return_logprob:
2032
2033
2034
2035
2036
2037
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

2038
            ret = GenerationBatchResult(
2039
2040
2041
2042
2043
2044
2045
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
2046
2047
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
2048
                bid=bid,
2049
                can_run_cuda_graph=can_run_cuda_graph,
2050
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
2051
2052
2053
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
2054
2055
2056
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
2057
        return ret
Chayenne's avatar
Chayenne committed
2058

2059
2060
2061
2062
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
2063
        launch_done: Optional[threading.Event] = None,
2064
    ):
Lianmin Zheng's avatar
Lianmin Zheng committed
2065
        if batch.forward_mode.is_decode():
2066
            self.process_batch_result_decode(batch, result, launch_done)
2067
2068
2069
2070
2071
2072
2073
2074
            for req in batch.reqs:
                trace_slice(
                    "decode loop",
                    req.rid,
                    auto_next_anon=not req.finished(),
                    thread_finish_flag=req.finished(),
                )

2075
        elif batch.forward_mode.is_extend():
2076
            self.process_batch_result_prefill(batch, result, launch_done)
2077
2078
2079
2080
2081
2082
2083
            for req in batch.reqs:
                trace_slice(
                    "prefill",
                    req.rid,
                    auto_next_anon=not req.finished(),
                    thread_finish_flag=req.finished(),
                )
2084
2085
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
2086
                self.tp_worker.resolve_last_batch_result(launch_done)
2087
                self.set_next_batch_sampling_info_done(batch)
2088
        elif batch.forward_mode.is_dummy_first():
2089
            self.set_next_batch_sampling_info_done(batch)
Lianmin Zheng's avatar
Lianmin Zheng committed
2090

2091
2092
2093
        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
2094
2095
2096
2097
2098
2099
2100
        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

2101
2102
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        return self.prepare_mlp_sync_batch_raw(
2103
2104
2105
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
2106
            tp_group=self.tp_group,
2107
2108
2109
2110
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
2111
            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
2112
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
2113
2114
2115
        )

    @staticmethod
2116
    def prepare_mlp_sync_batch_raw(
2117
2118
2119
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
2120
        tp_group,
2121
2122
2123
2124
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
2125
        require_mlp_tp_gather: bool,
2126
        disable_overlap_schedule: bool,
2127
    ):
2128
2129
2130
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
2131
            num_tokens_for_logprob = 0
2132
2133
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
2134
            num_tokens_for_logprob = num_tokens
2135
2136
        else:
            num_tokens = local_batch.extend_num_tokens
2137
            num_tokens_for_logprob = sum(
Lianmin Zheng's avatar
Lianmin Zheng committed
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
2155
2156

        tbo_preparer = TboDPAttentionPreparer()
2157
2158
2159
2160
2161
2162
        if disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            group = tp_group.cpu_group
            device = "cpu"
2163

Lianmin Zheng's avatar
Lianmin Zheng committed
2164
2165
2166
2167
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
2168
                num_tokens_for_logprob,
Lianmin Zheng's avatar
Lianmin Zheng committed
2169
                is_extend_in_batch,
2170
2171
2172
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                ),
Lianmin Zheng's avatar
Lianmin Zheng committed
2173
2174
            ],
            dtype=torch.int64,
2175
            device=device,
Lianmin Zheng's avatar
Lianmin Zheng committed
2176
2177
        )
        global_info = torch.empty(
2178
            (dp_size, attn_tp_size, 6),
Lianmin Zheng's avatar
Lianmin Zheng committed
2179
            dtype=torch.int64,
2180
            device=device,
Lianmin Zheng's avatar
Lianmin Zheng committed
2181
        )
2182
        torch.distributed.all_gather_into_tensor(
Lianmin Zheng's avatar
Lianmin Zheng committed
2183
2184
            global_info.flatten(),
            local_info,
2185
            group=group,
2186
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
2187
2188
2189
2190
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()
2191

2192
2193
2194
2195
        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
2196
        if local_batch is None and max(global_num_tokens) > 0:
2197
            local_batch = get_idle_batch()
2198
2199

        if local_batch is not None:
2200
            # TODO: handle the case when moe_dense_tp_size != 1
2201
            if not require_mlp_tp_gather:
2202
2203
2204
2205
2206
2207
2208
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
2209
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
2210
2211
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode
2212

2213
            # Check forward mode for cuda graph
2214
            if not disable_cuda_graph:
Lianmin Zheng's avatar
Lianmin Zheng committed
2215
                local_batch.can_run_dp_cuda_graph = can_cuda_graph
2216

2217
        return local_batch
2218
2219
2220
2221
2222

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
2223
            self.token_to_kv_pool_allocator,
2224
2225
2226
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
2227
            self.spec_algorithm,
2228
2229
2230
2231
        )
        idle_batch.prepare_for_idle()
        return idle_batch

2232
2233
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
2234

2235
        num_ready_reqs = 0
2236
        num_timeout_reqs = 0
2237
2238
        for req in self.grammar_queue:
            try:
2239
2240
2241
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
2242
                req.grammar = req.grammar.result(timeout=0.03)
2243
2244
2245
2246
2247
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
2248
2249
                num_ready_reqs += 1
            except futures._base.TimeoutError:
2250
                req.grammar_wait_ct += 1
2251
2252
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
2253
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
2254
                    num_timeout_reqs = 1
2255
2256
                break

2257
        if self.server_args.enable_dp_attention:
2258
2259
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
2260
        else:
2261
2262
2263
2264
2265
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
2266
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
2267
2268
2269
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
2270
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()
2271

2272
            for i in range(num_ready_reqs, num_ready_reqs_max):
2273
                req = self.grammar_queue[i]
2274
2275
                if req.finished():  # It is aborted by AbortReq
                    continue
2276
                req.grammar = req.grammar.result()
2277
2278
2279
2280
2281
2282
2283
2284
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs
2285

2286
2287
2288
2289
2290
2291
2292
        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max
2293

2294
        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
2295
2296
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

2297
2298
2299
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        if batch.next_batch_sampling_info:
            if batch.next_batch_sampling_info.grammars is not None:
2300
2301
2302
                if self.disaggregation_mode != DisaggregationMode.PREFILL:
                    batch.next_batch_sampling_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
2303
2304
            batch.next_batch_sampling_info.sampling_info_done.set()

2305
2306
2307
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
2308
        self.watchdog_last_time = time.perf_counter()
2309
2310

        while True:
2311
            current = time.perf_counter()
2312
2313
2314
2315
2316
2317
2318
2319
2320
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

Lianmin Zheng's avatar
Lianmin Zheng committed
2321
2322
        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
Hanming Lu's avatar
Hanming Lu committed
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
2343
2344
2345
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
Hanming Lu's avatar
Hanming Lu committed
2346
                f"{info_msg}"
Lianmin Zheng's avatar
Lianmin Zheng committed
2347
2348
            )

2349
        pyspy_dump_schedulers()
Lianmin Zheng's avatar
Lianmin Zheng committed
2350
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
2351
2352
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
Lianmin Zheng's avatar
Lianmin Zheng committed
2353
2354

        # Wait for some time so that the parent process can print the error.
2355
2356
2357
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2358
2359
2360
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        success = self.flush_cache()
        return FlushCacheReqOutput(success=success)
2361

2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
        if self.enable_hierarchical_cache:
            self.tree_cache.clear_storage_backend()
            logger.info("Hierarchical cache cleared successfully!")
            if_success = True
        else:
            logging.warning("Hierarchical cache is not enabled.")
            if_success = False
        return ClearHiCacheReqOutput(success=if_success)

2372
    def flush_cache(self):
2373
        """Flush the memory pool and cache."""
2374
2375
2376
2377
2378
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
2379
2380
            self.cur_batch = None
            self.last_batch = None
2381
            self.tree_cache.reset()
2382
            if self.grammar_backend:
Lianmin Zheng's avatar
Lianmin Zheng committed
2383
                self.grammar_backend.reset()
2384
            self.req_to_token_pool.clear()
2385
            self.token_to_kv_pool_allocator.clear()
2386

2387
2388
            if self.draft_worker:
                self.draft_worker.clear_cache_pool()
2389
2390
2391
2392
2393

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
2394
2395
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
2396
2397
2398
2399
2400
2401
2402
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
Lianmin Zheng's avatar
Lianmin Zheng committed
2403
                f"#running-req: {len(self.running_batch.reqs)}"
2404
2405
2406
2407
            )
            if_success = False
        return if_success

2408
    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
Liangsheng Yin's avatar
Liangsheng Yin committed
2409
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
2410

Hanming Lu's avatar
Hanming Lu committed
2411
        if self.is_hybrid:
2412
            num_tokens_full = (
Hanming Lu's avatar
Hanming Lu committed
2413
2414
2415
2416
                self.full_tokens_per_layer
                - self.token_to_kv_pool_allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
2417
            num_tokens_swa = (
Hanming Lu's avatar
Hanming Lu committed
2418
2419
2420
2421
                self.swa_tokens_per_layer
                - self.token_to_kv_pool_allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
2422
            num_tokens = max(num_tokens_full, num_tokens_swa)
Hanming Lu's avatar
Hanming Lu committed
2423
        else:
2424
            num_tokens = (
Hanming Lu's avatar
Hanming Lu committed
2425
2426
2427
2428
                self.max_total_num_tokens
                - self.token_to_kv_pool_allocator.available_size()
                - self.tree_cache.evictable_size()
            )
2429
2430
2431
2432

        # Tokens in waiting queue, bootstrap queue, prealloc queue
        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
        num_waiting_reqs = len(self.waiting_queue)
Liangsheng Yin's avatar
Liangsheng Yin committed
2433
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
2434
            num_tokens += sum(
Liangsheng Yin's avatar
Liangsheng Yin committed
2435
2436
2437
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
2438
            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
Liangsheng Yin's avatar
Liangsheng Yin committed
2439
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
2440
            num_tokens += sum(
Liangsheng Yin's avatar
Liangsheng Yin committed
2441
2442
2443
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )
2444
            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
Liangsheng Yin's avatar
Liangsheng Yin committed
2445

2446
2447
2448
2449
2450
2451
        return GetLoadReqOutput(
            dp_rank=self.dp_rank,
            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
            num_waiting_reqs=num_waiting_reqs,
            num_tokens=num_tokens,
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
2452

2453
2454
2455
    def get_internal_state(self, recv_req: GetInternalStateReq):
        ret = dict(global_server_args_dict)
        ret["last_gen_throughput"] = self.last_gen_throughput
2456
2457
2458
2459
2460
2461
2462
2463
2464
        ret["memory_usage"] = {
            "weight": round(
                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
            ),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }
2465

2466
2467
2468
        ret["memory_usage"]["graph"] = round(
            self.tp_worker.worker.model_runner.graph_mem_usage, 2
        )
2469

2470
2471
2472
2473
2474
2475
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict
Liangsheng Yin's avatar
Liangsheng Yin committed
2476
2477

        return GetInternalStateReqOutput(internal_state=ret)
2478
2479
2480
2481
2482

    def set_internal_state(self, recv_req: SetInternalStateReq):
        server_args_dict = recv_req.server_args
        args_allow_update = set(
            [
2483
                "max_micro_batch_size",
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
                "speculative_accept_threshold_single",
                "speculative_accept_threshold_acc",
            ]
        )
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logging.warning(f"Updating {k} is not supported.")
                if_success = False
                break
2494
2495
2496
2497
2498
2499
2500
2501
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logging.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
2502
2503
2504
2505
2506
2507
2508
2509
2510
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
2511
            logger.info(f"Global server args updated! {global_server_args_dict=}")
2512
2513
2514
2515
2516
        return SetInternalStateReqOutput(
            updated=True,
            server_args=global_server_args_dict,
        )

2517
2518
2519
2520
2521
2522
2523
2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
    def handle_rpc_request(self, recv_req: RpcReqInput):
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        exec = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            exec = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if not exec else str(exec))

2536
2537
    def abort_request(self, recv_req: AbortReq):
        """Abort every request selected by ``recv_req``, wherever it currently lives.

        A request is selected when ``recv_req.abort_all`` is set or when its
        ``rid`` starts with ``recv_req.rid``. The abort mechanism depends on
        where the request currently is:

        * waiting queue: popped immediately and an ``AbortReq`` is sent back
          to the tokenizer (abort method 1);
        * grammar queue: marked finished-with-abort (abort method 2);
        * PD-disaggregation bootstrap / inflight / prealloc / transfer queues:
          the KV sender/receiver is aborted when it exposes ``abort``;
        * running batch: flagged with ``to_abort`` (abort method 3).
        """

        def is_target(rid):
            # Truthy when the request with this rid should be aborted.
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [
            i for i, req in enumerate(self.waiting_queue) if is_target(req.rid)
        ]

        # Pop from the highest index down so the remaining indices stay valid.
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            if self.enable_hicache_storage:
                # to release prefetch events associated with the request
                self.tree_cache.release_aborted_request(req.rid)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.tree_cache.cache_finished_req(req)
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if is_target(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled.
        # Only targeted requests are logged; the previous code logged every
        # queued request before checking whether it matched.
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if is_target(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if is_target(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if is_target(decode_req.req.rid):
                    logger.debug(
                        f"Abort prealloc queue request. {decode_req.req.rid=}"
                    )
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if is_target(decode_req.req.rid):
                    logger.debug(
                        f"Abort transfer queue request. {decode_req.req.rid=}"
                    )
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and is_target(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2616

2617
2618
2619
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine; not implemented in this class.

        NOTE(review): presumably meant to be overridden elsewhere — confirm.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

2620
2621
2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """Load a new LoRA adapter in place, from disk or Hugging Face.

        The work is delegated to ``self.tp_worker``; its result is returned
        unchanged.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload a LoRA adapter.

        Delegates to ``self.tp_worker.unload_lora_adapter`` and returns its
        result as-is.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2636
2637
2638
2639
    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
        """Forward a multi-tokenizer registration to the detokenizer.

        Returns the same ``recv_req`` object that was forwarded.
        """
        detokenizer_channel = self.send_to_detokenizer
        detokenizer_channel.send_pyobj(recv_req)
        return recv_req

2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
2652
2653
2654
2655
    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Initialise the communication group between the seed and client instances.

        The TP worker performs the setup and reports a ``(success, message)``
        pair, which is wrapped into the output object.
        """
        ok, msg = self.tp_worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return InitWeightsSendGroupForRemoteInstanceReqOutput(ok, msg)

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Send this (seed) instance's weights to the destination instance.

        The transfer is done by ``self.tp_worker``; its ``(success, message)``
        pair is wrapped into the output object.
        """
        ok, msg = self.tp_worker.send_weights_to_remote_instance(recv_req)
        return SendWeightsToRemoteInstanceReqOutput(ok, msg)

2656
2657
2658
2659
2660
2661
2662
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` as ``self.forward_sleep_time``.

        ``None`` and non-positive values are both normalised to ``None``;
        positive values are stored unchanged.
        """
        sleep_time = recv_req.forward_sleep_time
        self.forward_sleep_time = (
            None if sleep_time is None or sleep_time <= 0 else sleep_time
        )
        return SlowDownReqOutput()

2663
2664
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump the global expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is not one of the three supported
                ``ExpertDistributionReq`` values.
        """
        # (request value, recorder method name); checked in order with ``==``.
        actions = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for request_value, method_name in actions:
            if recv_req == request_value:
                # The recorder is looked up only once a matching action is found.
                recorder = get_global_expert_distribution_recorder()
                getattr(recorder, method_name)()
                return ExpertDistributionReqOutput()
        raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
2673

2674
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new ``Session`` keyed by ``recv_req.session_id``.

        The request is rejected (a warning is logged and the output reports
        ``False``) when the id is already in use or when it is ``None``.
        """
        session_id = recv_req.session_id

        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return OpenSessionReqOutput(session_id, True)
2688
2689
2690
2691
2692
2693
2694
2695
2696

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``.

        Unknown ids are not an error: a warning is logged and nothing changes.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2697
2698
    def get_print_prefix(self):
        """Return a rank tag such as ``" DP0 TP1 PP2"``.

        Each part is included only when it applies:

        * DP when ``attn_dp_rank`` is set;
        * TP when ``tp_size > 1``;
        * PP when ``pp_size > 1``.

        Returns ``""`` when none apply.
        """
        tags = []
        if self.attn_dp_rank is not None:
            tags.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            tags.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            tags.append(f" PP{self.pp_rank}")
        return "".join(tags)

2707
2708
    def current_scheduler_metrics_enabled(self):
        """Whether this scheduler should report metrics.

        True on attention-TP rank 0; on other ranks, whatever
        ``enable_metrics_for_all_schedulers`` holds.
        """
        is_attn_tp_leader = self.attn_tp_rank == 0
        return is_attn_tp_leader or self.enable_metrics_for_all_schedulers
2709

2710
2711
2712
    def maybe_sleep_on_idle(self):
        """Delegate to ``self.idle_sleeper.maybe_sleep()`` when a sleeper is configured."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2713

2714
2715
2716
2717
2718
2719
    def handle_freeze_gc(self, recv_req: FreezeGCReq):
        """Freeze this scheduler's GC, then forward ``recv_req`` to the detokenizer.

        Always returns ``None``.
        """
        freeze_gc("Scheduler")
        self.send_to_detokenizer.send_pyobj(recv_req)

2720

2721
2722
2723
2724
2725
2726
2727
class IdleSleeper:
    """Block the process while it is idle instead of busy-looping.

    In setups with long inactivity periods it is desirable to reduce system
    power consumption when sglang has nothing to do. Besides saving power,
    this leaves more CPU thermal headroom for when a request eventually
    arrives. It matters most when several GPUs are connected, since each GPU
    would otherwise pin one CPU thread at 100% usage.

    The simplest solution is to use ``zmq.Poller`` on all sockets that may
    receive data needing immediate handling.
    """

    def __init__(self, sockets, poll_timeout_ms: int = 1000):
        """Register ``sockets`` for read-readiness polling.

        Args:
            sockets: zmq sockets whose incoming data should end a sleep.
            poll_timeout_ms: Upper bound, in milliseconds, on one call to
                :meth:`maybe_sleep`. Defaults to 1000, the previously
                hard-coded value.
        """
        self.poller = zmq.Poller()
        self.poll_timeout_ms = poll_timeout_ms
        # Monotonic clock: interval bookkeeping must not be disturbed by
        # wall-clock adjustments (NTP, manual changes).
        self.last_empty_time = time.monotonic()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Wait for socket readiness, then occasionally free the CUDA cache.

        The wait lasts at most ``poll_timeout_ms`` milliseconds. Afterwards,
        ``torch.cuda.empty_cache()`` is called when more than
        ``global_config.torch_empty_cache_interval`` seconds have passed since
        the last call. Nothing is freed when that interval is <= 0.
        """
        self.poller.poll(self.poll_timeout_ms)

        interval = global_config.torch_empty_cache_interval
        if interval > 0:
            now = time.monotonic()
            if now - self.last_empty_time > interval:
                self.last_empty_time = now
                torch.cuda.empty_cache()
2748

2749

2750
2751
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health-check requests are recognised by a ``rid`` starting with
    ``"HEALTH_CHECK"``. Objects without a string ``rid`` (no attribute, or a
    non-string value such as ``None``) are never health checks. Previously a
    non-string ``rid`` raised ``AttributeError``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
2752

2753
2754

def is_work_request(recv_req):
    """Return True if ``recv_req`` is a tokenized generate/embedding request.

    Both single and batched variants count.
    """
    work_request_types = (
        TokenizedGenerateReqInput,
        TokenizedEmbeddingReqInput,
        BatchTokenizedGenerateReqInput,
        BatchTokenizedEmbeddingReqInput,
    )
    return isinstance(recv_req, work_request_types)
2764
2765


2766
2767
2768
2769
2770
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
    balance_meta: Optional[DPBalanceMeta] = None,
):
    """Entry point of a scheduler subprocess.

    Steps:

    1. Configure the process: tracing, NUMA binding, process title, fault
       handler, logger and optional CPU affinity.
    2. Build a ``Scheduler`` and report readiness through ``pipe_writer``.
    3. Run the event loop selected by disaggregation mode, pipeline-parallel
       size and overlap setting.

    Any exception raised while constructing or running the scheduler is
    logged, and ``SIGQUIT`` is sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolve it before anything that records the rank (tracing thread info,
    # process title, logger prefix) so they all see the effective DP rank.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = "Scheduler"
            trace_set_thread_info(thread_label, tp_rank, dp_rank)

    if (numa_node := server_args.numa_node) is not None:
        numa_bind_to_node(numa_node[gpu_id])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.ep_size > 1:
        prefix += f" EP{moe_ep_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            elif server_args.pp_size > 1:
                scheduler.event_loop_pp_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        tb = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {tb}")
        parent_process.send_signal(signal.SIGQUIT)