# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A scheduler that manages a tensor parallel GPU worker."""

import faulthandler
import logging
import os
import signal
import sys
import threading
import time
from collections import deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace
from typing import Deque, Dict, List, Optional, Tuple, Union

import psutil
import setproctitle
import torch
import zmq
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext
from torch.distributed import barrier

from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
    DecodeKVCacheOffloadManager,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    MetadataBuffers,
    ReqToMetadataIdxAllocator,
    TransferBackend,
    prepare_abort,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
    AbortReq,
    BatchTokenizedEmbeddingReqInput,
    BatchTokenizedGenerateReqInput,
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    CloseSessionReqInput,
    DestroyWeightsUpdateGroupReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    ExpertDistributionReqType,
    FlushCacheReqInput,
    FlushCacheReqOutput,
    FreezeGCReq,
    GetInternalStateReq,
    GetInternalStateReqOutput,
    GetLoadReqInput,
    GetLoadReqOutput,
    GetWeightsByNameReqInput,
    HealthCheckOutput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsSendGroupForRemoteInstanceReqOutput,
    InitWeightsUpdateGroupReqInput,
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    MultiTokenizerRegisterReq,
    MultiTokenizerWrapper,
    OpenSessionReqInput,
    OpenSessionReqOutput,
    ProfileReq,
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
    RpcReqInput,
    RpcReqOutput,
    SendWeightsToRemoteInstanceReqInput,
    SendWeightsToRemoteInstanceReqOutput,
    SetInternalStateReq,
    SetInternalStateReqOutput,
    SlowDownReqInput,
    SlowDownReqOutput,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    ModelWorkerBatch,
    MultimodalInputs,
    Req,
    RequestStage,
    ScheduleBatch,
    global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.tracing.trace import (
    process_tracing_init,
    trace_set_proc_propagate_context,
    trace_set_thread_info,
    trace_slice_batch,
    trace_slice_end,
    trace_slice_start,
)
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
    DynamicGradMode,
    broadcast_pyobj,
    configure_gc_logger,
    configure_logger,
    disable_request_logging,
    freeze_gc,
    get_available_gpu_memory,
    get_bool_env_var,
    get_int_env_var,
    get_zmq_socket,
    kill_itself_when_parent_died,
    numa_bind_to_node,
    point_to_point_pyobj,
    pyspy_dump_schedulers,
    require_mlp_sync,
    require_mlp_tp_gather,
    set_gpu_proc_affinity,
    set_random_seed,
    suppress_other_loggers,
)
from sglang.srt.utils.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
from sglang.utils import TypeBasedDispatcher, get_exception_traceback

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@dataclass
class GenerationBatchResult:
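    # Outputs of one forward pass (a prefill or decode step) produced by the model worker.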
    logits_output: Optional[LogitsProcessorOutput] = None
    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

    # For output processing
    extend_input_len_per_req: Optional[List[int]] = None
    extend_logprob_start_len_per_req: Optional[List[int]] = None

    # For overlap scheduling
    copy_done: Optional[torch.cuda.Event] = None
    delay_sample_launch: bool = False
    forward_batch: Optional[ForwardBatch] = None
    future_indices: Optional[FutureIndices] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Copy tensors to CPU in overlap scheduling.
        Only the tensors which are needed for processing results are copied,
        e.g., next_token_ids, logits outputs
        """
        if return_logprob:
            if self.logits_output.next_token_logits is not None:
                self.logits_output.next_token_logits = (
                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
                )
            if self.logits_output.input_token_logprobs is not None:
                self.logits_output.input_token_logprobs = (
                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                )
        if self.logits_output.hidden_states is not None:
            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
                "cpu", non_blocking=True
            )
        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
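        # Record the completion event on the current stream so consumers can wait
        # for these asynchronous device-to-CPU copies before reading the tensors.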
        self.copy_done.record()

    @classmethod
    def from_pp_proxy(
        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
    ):
        # TODO(lsyin): refactor PP and avoid using dict
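        # Reconstruct a batch result on this stage from the proxy tensors received
        # from another pipeline stage.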
        proxy_dict = next_pp_outputs.tensors
        return cls(
            logits_output=logits_output,
            pp_hidden_states_proxy_tensors=None,
            next_token_ids=next_pp_outputs["next_token_ids"],
            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
            extend_logprob_start_len_per_req=proxy_dict.get(
                "extend_logprob_start_len_per_req", None
            ),
            can_run_cuda_graph=can_run_cuda_graph,
        )


@dataclass
class EmbeddingBatchResult:
    embeddings: torch.Tensor


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        moe_ep_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.moe_ep_rank = moe_ep_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.moe_ep_size = server_args.ep_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_priority_scheduling = server_args.enable_priority_scheduling
        self.schedule_low_priority_values_first = (
            server_args.schedule_low_priority_values_first
        )
        self.priority_scheduling_preemption_threshold = (
            server_args.priority_scheduling_preemption_threshold
        )
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = bool(
            server_args.kv_events_config and tp_rank == 0
        )
        self.enable_trace = server_args.enable_trace
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init model config
        self.model_config = ModelConfig.from_server_args(server_args)

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer
        self.init_tokenizer()

        # Init moe config
        self.init_moe_config()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        from sglang.srt.managers.tp_worker import TpModelWorker

        self.tp_worker = TpModelWorker(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_standalone():
            from sglang.srt.speculative.standalone_worker import StandaloneWorker

            self.draft_worker = StandaloneWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_ngram():
            from sglang.srt.speculative.ngram_worker import NGRAMWorker

            self.draft_worker = NGRAMWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Dispatch the model worker
        if self.spec_algorithm.is_none():
            self.model_worker = self.tp_worker
        else:
            self.model_worker = self.draft_worker

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["pp_max_micro_batch_size"] is None:
            global_server_args_dict["pp_max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.default_stream: CudaStream = torch.get_device_module(
            self.device
        ).current_stream()
        if self.device == "cpu":
            self.default_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
            self.enable_priority_scheduling,
            self.schedule_low_priority_values_first,
        )
        # Enable preemption for priority scheduling.
        self.try_preemption = self.enable_priority_scheduling

        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.offload_tags = set()
        self.init_profiler()

        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)

        if self.enable_kv_cache_events:
            self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init prefill kv split size when deterministic inference is enabled with various attention backends
        self.init_deterministic_inference_config()

        # Init overlap
        self.init_overlap()

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
                ),
                (
                    SendWeightsToRemoteInstanceReqInput,
                    self.send_weights_to_remote_instance,
                ),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (FreezeGCReq, self.handle_freeze_gc),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
                (GetLoadReqInput, self.get_load),
            ]
        )

    def init_deterministic_inference_config(self):
        """Initialize deterministic inference configuration for different attention backends."""
        if not self.server_args.enable_deterministic_inference:
            self.truncation_align_size = None
            return

        backend_sizes = {
            "flashinfer": ("SGLANG_FLASHINFER_PREFILL_SPLIT_TILE_SIZE", 4096),
            "triton": ("SGLANG_TRITON_PREFILL_TRUNCATION_ALIGN_SIZE", 4096),
        }
        env_var, default_size = backend_sizes.get(
            self.server_args.attention_backend, (None, None)
        )
        self.truncation_align_size = (
            get_int_env_var(env_var, default_size) if env_var else None
        )

    def init_tokenizer(self):
        server_args = self.server_args
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
                self.tokenizer = get_tokenizer_from_processor(self.processor)
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

    def init_memory_pool_and_cache(self):
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            if self.is_hybrid:
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        else:
            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
                # lazy import to avoid JIT overhead
                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp

                self.tree_cache = RadixCacheCpp(
                    disable=False,
                    use_hicache=self.enable_hierarchical_cache,
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool=self.token_to_kv_pool_allocator,
                    tp_cache_group=self.tp_cpu_group,
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                )
            elif self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    tp_cache_group=(
                        self.attn_tp_cpu_group
                        if self.server_args.enable_dp_attention
                        else self.tp_cpu_group
                    ),
                    page_size=self.page_size,
                    eviction_policy=server_args.radix_eviction_policy,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                    hicache_io_backend=server_args.hicache_io_backend,
                    hicache_mem_layout=server_args.hicache_mem_layout,
                    enable_metrics=self.enable_metrics,
                    hicache_storage_backend=server_args.hicache_storage_backend,
                    hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                    model_name=server_args.served_model_name,
                    storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
                self.tp_worker.register_hicache_layer_transfer_counter(
                    self.tree_cache.cache_controller.layer_done_counter
                )
            elif self.is_hybrid:
                assert (
                    self.server_args.disaggregation_mode == "null"
                ), "Hybrid mode does not support disaggregation yet"
                self.tree_cache = SWARadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    sliding_window_size=self.sliding_window_size,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )
            elif server_args.enable_lmcache:
                from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
                    LMCRadixCache,
                )

                self.tree_cache = LMCRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    model_config=self.model_config,
                    tp_size=self.tp_size,
                    rank=self.tp_rank,
                    tp_group=self.tp_group,
                    eviction_policy=server_args.radix_eviction_policy,
                )
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                    page_size=self.page_size,
                    disable=server_args.disable_radix_cache,
                    enable_kv_cache_events=self.enable_kv_cache_events,
                    eviction_policy=server_args.radix_eviction_policy,
                    is_eagle=self.spec_algorithm.is_eagle(),
                )

        if (
            server_args.disaggregation_mode == "decode"
            and server_args.disaggregation_decode_enable_offload_kvcache
        ):
            self.decode_offload_manager = DecodeKVCacheOffloadManager(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                tree_cache=self.tree_cache,
                server_args=self.server_args,
            )
        else:
            self.decode_offload_manager = None

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    (server_args.speculative_eagle_topk or 1)
                    * (server_args.speculative_num_steps or 1)
                )
            )
        )

        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

    def init_disaggregation(self):
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                hidden_states_dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    def init_overlap(self):
        if not self.enable_overlap:
            return

        self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream()
        self.forward_stream_ctx: CudaStreamContext = torch.get_device_module(
            self.device
        ).stream(self.forward_stream)
        self.copy_stream: CudaStream = torch.get_device_module(self.device).Stream()
        self.copy_stream_ctx: CudaStreamContext = torch.get_device_module(
            self.device
        ).stream(self.copy_stream)

        self.future_map = FutureMap(self.max_running_requests, self.device)
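        # Keep references to the last two model worker batches so their GPU tensors
        # are not garbage-collected while the forward stream may still be using them.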
        self.batch_record_buf = [None] * 2
        self.batch_record_ct = 0

    def record_batch_in_overlap(self, model_worker_batch: ModelWorkerBatch):
        # FIXME(lsyin): hacky way to keep a reference so the GPU tensors are not freed by torch GC.
        # NOTE: a more reliable approach is to record all tensors into the forward stream:
        #       - all future tensors should always be read from the future map;
        #       - all non-future tensors (produced only by the schedule stream) must keep
        #         a live reference for the duration of the forward pass.
        self.batch_record_ct = (self.batch_record_ct + 1) % 2
        self.batch_record_buf[self.batch_record_ct] = model_worker_batch

    def init_moe_config(self):
        if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
            initialize_moe_config(self.server_args)

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop."""
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        self.result_queue: Deque[Tuple[ScheduleBatch, GenerationBatchResult]] = deque()

        while True:
            self.launch_last_batch_sample_if_needed()

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                self.process_batch_result(tmp_batch, tmp_result)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism."""
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        pp_outputs: Optional[PPProxyTensors] = None
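        # Rotate over pp_size micro-batches: while one micro-batch runs on this stage,
        # the outputs of a previously launched micro-batch are received and post-processed.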
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids = result.next_token_ids
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None

                    output_result = GenerationBatchResult.from_pp_proxy(
                        logits_output=logits_output,
                        next_pp_outputs=next_pp_outputs,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        # FIXME(lsyin): remove this assert
                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.self_check_during_idle()

    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""

        if self.recv_skipper is not None:
            last_forward_mode = (
                self.last_batch.forward_mode if self.last_batch is not None else None
            )
            if not self.recv_skipper.handle(last_forward_mode):
                return []

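        # Only PP rank 0 with attention-TP rank 0 polls the tokenizer and RPC sockets;
        # later pipeline stages receive the requests point-to-point below.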
        if self.pp_rank == 0:
            if self.attn_tp_rank == 0:
                recv_reqs = []

                while True:
                    try:
                        recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_req)

                while True:
                    try:
                        recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(recv_rpc)
            else:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )
            else:
                recv_reqs = None

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs = [
                    req
                    for req in recv_reqs
                    if isinstance(
                        req,
                        (
                            TokenizedGenerateReqInput,
                            TokenizedEmbeddingReqInput,
                            BatchTokenizedGenerateReqInput,
                            BatchTokenizedEmbeddingReqInput,
                        ),
                    )
                ]
                control_reqs = [
                    req
                    for req in recv_reqs
                    if not isinstance(
                        req,
                        (
                            TokenizedGenerateReqInput,
                            TokenizedEmbeddingReqInput,
                            BatchTokenizedGenerateReqInput,
                            BatchTokenizedEmbeddingReqInput,
                        ),
                    )
                ]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        if self.enable_trace:
            for req in recv_reqs:
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    trace_set_proc_propagate_context(req.rid, req.trace_context)
                    trace_slice_start("", req.rid, anonymous=True)

        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None
                or not self.running_batch.is_empty()
                or len(self.offload_tags) > 0
            ):
                self.return_health_check_ct += 1
                continue

            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
            if isinstance(recv_req, MultiTokenizerWrapper):
                worker_id = recv_req.worker_id
                recv_req = recv_req.obj
                output = self._request_dispatcher(recv_req)
                if output is not None:
                    output = MultiTokenizerWrapper(worker_id, output)
                    self.send_to_tokenizer.send_pyobj(output)
                continue

            output = self._request_dispatcher(recv_req)
            if output is not None:
                if isinstance(output, RpcReqOutput):
                    if self.recv_from_rpc is not None:
                        self.recv_from_rpc.send_pyobj(output)
                else:
                    self.send_to_tokenizer.send_pyobj(output)

    def init_req_max_new_tokens(self, req):
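        """Clamp the request's max_new_tokens so that prompt length + output length stays within max_req_len."""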
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_id=recv_req.lora_id,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                disagg_mode=self.disaggregation_mode,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
                priority=recv_req.priority,
                metrics_collector=(
                    self.metrics_collector if self.enable_metrics else None
                ),
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self.init_req_max_new_tokens(req)
                self._add_request_to_queue(req)
                return

        # Initialize before returning
        self.init_req_max_new_tokens(req)

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
            # to skip input logprob computation entirely
            if req.is_prefill_only:
                req.logprob_start_len = len(req.origin_input_ids)
            else:
                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
                req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if not req.is_prefill_only and req.logprob_start_len >= len(
            req.origin_input_ids
        ):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            elif req.sampling_params.structural_tag:
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def handle_batch_generate_request(
        self,
        recv_req: BatchTokenizedGenerateReqInput,
    ):
        """Handle optimized batch generate request."""
        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")

        # Process each request in the batch
        for tokenized_req in recv_req:
            self.handle_generate_request(tokenized_req)

    def _prefetch_kvcache(self, req: Req):
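        """If hierarchical cache storage is enabled, start prefetching this request's uncached input suffix from host storage."""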
        if self.enable_hicache_storage:
            req.init_next_round_input(self.tree_cache)
            if req.last_node.backuped:
                # only to initiate the prefetch if the last node is backuped
                # otherwise, the allocated GPU memory must be locked for integrity
                last_hash = req.last_host_node.get_last_hash_value()
                matched_len = len(req.prefix_indices) + req.host_hit_length
                new_input_tokens = req.fill_ids[matched_len:]
                self.tree_cache.prefetch_from_storage(
                    req.rid, req.last_host_node, new_input_tokens, last_hash
                )

    def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
        if self.disaggregation_mode == DisaggregationMode.NULL:
            self._set_or_validate_priority(req)
            if self._abort_on_queued_limit(req):
                return
            self._prefetch_kvcache(req)
            self.waiting_queue.append(req)
            req.time_stats.wait_queue_entry_time = time.perf_counter()
            trace_slice_end("process req", req.rid, auto_next_anon=True)
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            self._prefetch_kvcache(req)
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            req.time_stats.prefill_bootstrap_queue_entry_time = time.perf_counter()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req, is_retracted=is_retracted)
            if not is_retracted:
                req.time_stats.decode_prealloc_queue_entry_time = time.perf_counter()
        else:
            raise ValueError(f"Invalid {self.disaggregation_mode=}")

    def _set_or_validate_priority(self, req: Req):
        """Set the default priority value, or abort the request based on the priority scheduling mode."""
        if self.enable_priority_scheduling and req.priority is None:
            if self.schedule_low_priority_values_first:
                req.priority = sys.maxsize
            else:
                req.priority = -sys.maxsize - 1
        elif not self.enable_priority_scheduling and req.priority is not None:
            abort_req = AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": "Using priority is disabled for this server. Please send a new request without a priority.",
                },
                rid=req.rid,
            )
            self.send_to_tokenizer.send_pyobj(abort_req)

    def _abort_on_queued_limit(self, recv_req: Req) -> bool:
        """Abort an incoming or existing request if the waiting queue is full. Returns True if the incoming request is aborted."""
        if (
            self.max_queued_requests is None
            or len(self.waiting_queue) + 1 <= self.max_queued_requests
        ):
            return False

        # Reject the incoming request by default.
        req_to_abort = recv_req
        message = "The request queue is full."
        if self.enable_priority_scheduling:
            # With priority scheduling, consider aborting an existing request based on the priority.
            # direction = 1  => smaller number = higher priority; -1 => larger number = higher priority.
            # max(...) with key (direction * priority, queue_time_start) picks the least-preferred request.
            # Tie: later queue_time_start (newer) is evicted first. Preempt only if strictly better.
            direction = 1 if self.schedule_low_priority_values_first else -1
            key_fn = lambda item: (
                direction * item[1].priority,
                item[1].time_stats.wait_queue_entry_time,
            )
            idx, candidate_req = max(enumerate(self.waiting_queue), key=key_fn)
            abort_existing_req = (
                direction * recv_req.priority < direction * candidate_req.priority
            )
            if abort_existing_req:
                self.waiting_queue.pop(idx)
                req_to_abort = candidate_req
                message = "The request is aborted by a higher priority request."

        self.send_to_tokenizer.send_pyobj(
            AbortReq(
                finished_reason={
                    "type": "abort",
                    "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                    "message": message,
                },
                rid=req_to_abort.rid,
            )
        )
        return req_to_abort.rid == recv_req.rid

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
            priority=recv_req.priority,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def handle_batch_embedding_request(
        self,
        recv_req: BatchTokenizedEmbeddingReqInput,
    ):
        """Handle optimized batch embedding request."""
        logger.debug(
            f"Processing batch embedding request with {len(recv_req)} requests"
        )

        # Process each request in the batch
        for tokenized_req in recv_req:
            self.handle_embedding_request(tokenized_req)

    def self_check_during_idle(self):
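        """Idle-time housekeeping: check for memory leaks, sanity-check the tree cache, reset new_token_ratio, and maybe sleep."""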
        self.check_memory()
        self.check_tree_cache()
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()

    def check_memory(self):
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            memory_leak = (available_size + evictable_size) != (
                # self.max_total_num_tokens
                # if not self.enable_hierarchical_cache
                # else self.max_total_num_tokens - protected_size
                self.max_total_num_tokens
                - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        if len(self.req_to_token_pool.free_slots) != req_total_size:
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                self.stats.num_prefill_prealloc_queue_reqs = len(
                    self.disagg_prefill_bootstrap_queue.queue
                )
                self.stats.num_prefill_inflight_queue_reqs = len(
                    self.disagg_prefill_inflight_queue
                )
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.stats.num_decode_prealloc_queue_reqs = len(
                    self.disagg_decode_prealloc_queue.queue
                )
                self.stats.num_decode_transfer_queue_reqs = len(
                    self.disagg_decode_transfer_queue.queue
                )
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def check_tree_cache(self):
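        """Sanity-check the SWA radix cache when running with a hybrid (full + sliding-window) KV cache."""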
        if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
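        """Return (num_used, token_usage, available_size, evictable_size) for the KV token pool."""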
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
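        """Return usage counters for both the full-attention and sliding-window (SWA) KV pools."""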
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
            # chunked request keeps its rid but will get a new req_pool_idx
            if self.tp_worker.worker.model_runner.mambaish_config is not None:
                self.req_to_token_pool.free(
                    self.chunked_req.req_pool_idx, free_mamba_cache=False
                )
            else:
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch.
            # For prefill-only batch, we can avoid going through decoding step.
            if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)

        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if need_dp_attn_preparation:
            ret = self.prepare_mlp_sync_batch(ret)

        return ret

    def get_num_allocatable_reqs(self, running_bs):
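        """Number of additional requests that can be scheduled, capped by pp_max_micro_batch_size and, when pp_size > 1, the free req_to_token_pool slots."""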
        res = global_server_args_dict["pp_max_micro_batch_size"] - running_bs
        if self.pp_size > 1:
            res = min(res, self.req_to_token_pool.available_size())
        return res

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        if self.try_preemption:
            # Reset batch_is_full to try preemption with a prefill adder.
            self.running_batch.batch_is_full = False

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if (
            self.get_num_allocatable_reqs(running_bs) <= 0
            and not self.chunked_req
            and not self.try_preemption
        ):
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
            self.priority_scheduling_preemption_threshold,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = set([req.lora_id for req in self.running_batch.reqs])

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:

            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
                lora_set
                | set([req.lora_id for req in adder.can_run_list])
                | set([req.lora_id])
            ):
                self.running_batch.batch_is_full = True
                break

            running_bs = len(self.running_batch.reqs) - len(adder.preempt_list)
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we also need to check against the actual available size of the req_to_token_pool.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True

            if self.running_batch.batch_is_full:
                if not self.try_preemption:
                    break
                if not adder.preempt_to_schedule(req, self.server_args):
                    break

            if self.enable_hicache_storage:
                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
                if not prefetch_done:
                    # skip staging requests that are ongoing prefetch
                    continue

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(
                req,
                has_chunked_req=(self.chunked_req is not None),
                truncation_align_size=self.truncation_align_size,
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.add_latency(RequestStage.PREFILL_WAITING)

        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]
        if adder.preempt_list:
            for req in adder.preempt_list:
                self._add_request_to_queue(req)

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs, 0)

        for req in can_run_list:
            if req.time_stats.forward_entry_time == 0:
                # Avoid update chunked request many times
                req.time_stats.forward_entry_time = time.perf_counter()
                if self.enable_metrics:
                    self.metrics_collector.observe_queue_time(
                        req.time_stats.get_queueing_time(),
                    )

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch."""
        initial_bs = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
            TEST_RETRACT and batch.batch_size() > 10
        ):
            old_ratio = self.new_token_ratio
            retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
                self.server_args
            )
            self.num_retracted_reqs = len(retracted_reqs)
            self.new_token_ratio = new_token_ratio
            for req in reqs_to_abort:
                self.send_to_tokenizer.send_pyobj(
                    AbortReq(abort_reason=req.to_abort_message, rid=req.rid)
                )

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
            )

            for req in retracted_reqs:
                self._add_request_to_queue(req, is_retracted=True)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < initial_bs:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch

    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch."""
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:

            batch_or_worker_batch = batch

            if self.spec_algorithm.is_none():
                # FIXME(lsyin): remove this if and finally unify the abstraction
                batch_or_worker_batch = batch.get_model_worker_batch()

            if self.enable_overlap:
                # FIXME: remove this assert
                assert isinstance(batch_or_worker_batch, ModelWorkerBatch)
                model_worker_batch = batch_or_worker_batch
                self.record_batch_in_overlap(model_worker_batch)

                # Sampling info will be modified during forward
                model_worker_batch.sampling_info = (
                    model_worker_batch.sampling_info.copy_for_forward()
                )

                bs = len(model_worker_batch.seq_lens)
                future_indices = self.future_map.alloc_future_indices(bs)

                with self.forward_stream_ctx:
                    self.forward_stream.wait_stream(self.default_stream)
                    self.future_map.resolve_future(model_worker_batch)
                    if batch.sampling_info.grammars is not None:
                        model_worker_batch.delay_sample_launch = True
                    batch_result = self.model_worker.forward_batch_generation(
                        batch_or_worker_batch
                    )
                    # FIXME(lsyin): maybe move this to forward_batch_generation
                    batch_result.copy_done = torch.get_device_module(
                        self.device
                    ).Event()
                    if not model_worker_batch.delay_sample_launch:
                        self.future_map.store_to_map(
                            future_indices, batch_result.next_token_ids
                        )
                        batch_result.copy_to_cpu()
                    else:
                        batch_result.future_indices = future_indices

                # FIXME(lsyin): move this assignment elsewhere
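                # Negative values are placeholders ("futures") for token ids that have not been
                # copied back yet; the future map resolves them to real ids before the next forward.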
                maybe_future_next_token_ids = -future_indices.indices
            else:
                batch_result = self.model_worker.forward_batch_generation(
                    batch_or_worker_batch
                )
                maybe_future_next_token_ids = batch_result.next_token_ids

            if not self.spec_algorithm.is_none():
                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
                self.update_spec_metrics(
                    batch.batch_size(), batch_result.num_accepted_tokens
                )

            # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
            #       which can probably be replaced by future_indices later [TODO(lsyin)].
            #       we shall still keep the original outputs, e.g. next_token_ids
            #       in the GenerationBatchOutput for processing after copy_done.
            batch.output_ids = maybe_future_next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None

            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            batch_result.extend_input_len_per_req = extend_input_len_per_req
            batch_result.extend_logprob_start_len_per_req = (
                extend_logprob_start_len_per_req
            )
            return batch_result
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(embeddings=embeddings)
        return ret

    def launch_last_batch_sample_if_needed(
        self,
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
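        """If the most recent batch deferred its sampling launch, run sampling now on the forward stream and publish the token ids to the future map."""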
        if len(self.result_queue) == 0:
            return

        tmp_batch, tmp_result = self.result_queue.popleft()

        tmp_result: GenerationBatchResult
        if not tmp_result.delay_sample_launch:
            self.result_queue.appendleft((tmp_batch, tmp_result))
            return

        with self.forward_stream_ctx:
            self.forward_stream.wait_stream(self.default_stream)
            tmp_result.next_token_ids = self.model_worker.model_runner.sample(
                tmp_result.logits_output,
                tmp_result.forward_batch,
            )
            future_indices = tmp_result.future_indices
            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
            tmp_result.copy_to_cpu()
            self.result_queue.appendleft((tmp_batch, tmp_result))

    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        if batch.forward_mode.is_decode():
            self.process_batch_result_decode(batch, result)
            if self.enable_trace:
                trace_slice_batch("decode loop", batch.reqs)

        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
            if self.enable_trace:
                trace_slice_batch("prefill", batch.reqs)

        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
                if result.copy_done is not None:
                    result.copy_done.synchronize()

        self.maybe_send_health_check_signal()

    def maybe_send_health_check_signal(self):
        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()
        if disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            group = tp_group.cpu_group
            device = "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                ),
Lianmin Zheng's avatar
Lianmin Zheng committed
2289
2290
            ],
            dtype=torch.int64,
2291
            device=device,
Lianmin Zheng's avatar
Lianmin Zheng committed
2292
2293
        )
        global_info = torch.empty(
2294
            (dp_size, attn_tp_size, 6),
Lianmin Zheng's avatar
Lianmin Zheng committed
2295
            dtype=torch.int64,
2296
            device=device,
Lianmin Zheng's avatar
Lianmin Zheng committed
2297
        )
2298
        torch.distributed.all_gather_into_tensor(
Lianmin Zheng's avatar
Lianmin Zheng committed
2299
2300
            global_info.flatten(),
            local_info,
2301
            group=group,
2302
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
2303
2304
2305
2306
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()
2307

2308
2309
2310
2311
        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
2312
        if local_batch is None and max(global_num_tokens) > 0:
2313
            local_batch = get_idle_batch()
2314
2315

        if local_batch is not None:
2316
            # TODO: handle the case when moe_dense_tp_size != 1
2317
            if not require_mlp_tp_gather:
2318
2319
2320
2321
2322
2323
2324
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
2325
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
2326
2327
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode
2328

2329
            # Check forward mode for cuda graph
2330
            if not disable_cuda_graph:
Lianmin Zheng's avatar
Lianmin Zheng committed
2331
                local_batch.can_run_dp_cuda_graph = can_cuda_graph
2332

2333
        return local_batch

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

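    # NOTE: grammar compilation runs asynchronously, so `move_ready_grammar_requests`
    # waits up to 0.03 s on the next unfinished future per call and counts the
    # attempts in `req.grammar_wait_ct`. Once the accumulated wait exceeds
    # GRAMMAR_TIMEOUT (i.e. more than GRAMMAR_TIMEOUT / 0.03 polls), the request is
    # cancelled, cached as INVALID_GRAMMAR_OBJ, and aborted. When tp_size > 1, a MAX
    # all-reduce makes every rank pop the same number of requests so the grammar
    # queues stay in lockstep.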
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue

                req.grammar = req.grammar.result(timeout=0.03)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
                    req.set_finish_with_abort(error_msg)

                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the line above,
                # not the time since the request entered the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    error_msg = f"Invalid grammar request: {req.grammar_key=}"
                    req.set_finish_with_abort(error_msg)
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        for i in range(num_ready_reqs, num_ready_reqs + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)

        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        for req in self.grammar_queue[:num_ready_reqs]:
            self._add_request_to_queue(req)
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

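    # NOTE: the watchdog below treats the scheduler as stuck when a batch is in
    # flight (`self.cur_batch is not None`) and `self.forward_ct` has not advanced
    # within `self.watchdog_timeout` seconds. It checks every `watchdog_timeout // 2`
    # seconds; on timeout it dumps batch and memory-pool state plus py-spy traces,
    # then signals the parent process with SIGQUIT to bring the server down.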
    def watchdog_thread(self):
        """A watchdog thread that kills the server if one forward batch takes too long."""
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        success = self.flush_cache()
        return FlushCacheReqOutput(success=success)

    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
        if self.enable_hierarchical_cache:
            self.tree_cache.clear_storage_backend()
            logger.info("Hierarchical cache cleared successfully!")
            if_success = True
        else:
            logger.warning("Hierarchical cache is not enabled.")
            if_success = False
        return ClearHiCacheReqOutput(success=if_success)

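    # NOTE: `flush_cache` only proceeds when the waiting queue is empty and no batch
    # is running (for PP, all micro-batches must be empty); otherwise it logs a
    # warning and returns False so the caller can retry once the server is idle. It
    # is typically reached through `flush_cache_wrapped` above, which wraps the
    # result in a FlushCacheReqOutput.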
    def flush_cache(self):
        """Flush the memory pool and cache."""
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if self.draft_worker:
                self.draft_worker.clear_cache_pool()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

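    # NOTE: the load reported below is an estimate of occupied KV-cache tokens plus
    # the input tokens of requests that are not scheduled yet. Illustrative example
    # with assumed numbers: if max_total_num_tokens=100_000, available_size=70_000,
    # and evictable_size=10_000, the in-use portion is 100_000 - 70_000 - 10_000
    # = 20_000 tokens, and each request still sitting in the waiting/bootstrap/
    # prealloc queues adds len(req.origin_input_ids) on top of that.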
    def get_load(self, recv_req: Optional[GetLoadReqInput] = None) -> GetLoadReqOutput:
        # TODO(lsyin): use dynamically maintained num_waiting_tokens

        if self.is_hybrid:
            num_tokens_full = (
                self.full_tokens_per_layer
                - self.token_to_kv_pool_allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
            num_tokens_swa = (
                self.swa_tokens_per_layer
                - self.token_to_kv_pool_allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
            num_tokens = max(num_tokens_full, num_tokens_swa)
        else:
            num_tokens = (
                self.max_total_num_tokens
                - self.token_to_kv_pool_allocator.available_size()
                - self.tree_cache.evictable_size()
            )

        # Tokens in waiting queue, bootstrap queue, prealloc queue
        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
        num_waiting_reqs = len(self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            num_tokens += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            num_tokens += sum(
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )
            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

        return GetLoadReqOutput(
            dp_rank=self.dp_rank,
            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
            num_waiting_reqs=num_waiting_reqs,
            num_tokens=num_tokens,
        )

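    # NOTE: the internal state returned below is the global server args dict plus a
    # few runtime metrics. Rough shape (values are placeholders, not measurements):
    #
    #     {
    #         ...global_server_args_dict,
    #         "last_gen_throughput": 123.4,
    #         "memory_usage": {"weight": 13.5, "kvcache": 40.2,
    #                          "token_capacity": 500000, "graph": 1.2},
    #         "avg_spec_accept_length": 2.7,  # only with speculative decoding
    #     }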
    def get_internal_state(self, recv_req: GetInternalStateReq):
        ret = dict(global_server_args_dict)
        ret["last_gen_throughput"] = self.last_gen_throughput
        ret["memory_usage"] = {
            "weight": round(
                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
            ),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }

        ret["memory_usage"]["graph"] = round(
            self.tp_worker.worker.model_runner.graph_mem_usage, 2
        )

        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=ret)

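    # NOTE: only the whitelisted keys below can be changed at runtime. A hypothetical
    # payload could carry server_args={"pp_max_micro_batch_size": 4}; the new value
    # must fall inside [1, max_running_requests // pp_size] or the update is rejected.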
    def set_internal_state(self, recv_req: SetInternalStateReq):
        server_args_dict = recv_req.server_args
        args_allow_update = set(
            [
                "pp_max_micro_batch_size",
                "speculative_accept_threshold_single",
                "speculative_accept_threshold_acc",
            ]
        )
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "pp_max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

    def handle_rpc_request(self, recv_req: RpcReqInput):
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        exc = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            exc = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if not exc else str(exc))

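    # NOTE: `abort_request` uses three strategies depending on how far a request has
    # progressed: (1) pop it from the waiting queue and notify the TokenizerManager,
    # (2) finish grammar-queue requests via `set_finish_with_abort` so a cheapened
    # prefill cleans them up, and (3) flag running requests with `to_abort=True` so
    # the next forward pass releases their KV cache through the normal finish path.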
    def abort_request(self, recv_req: AbortReq):
        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                to_del.append(i)

        # Delete in reverse order to avoid index shifting issues
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            if self.enable_hicache_storage:
                # to release prefetch events associated with the request
                self.tree_cache.release_aborted_request(req.rid)
            self.send_to_tokenizer.send_pyobj(AbortReq(rid=req.rid))
            # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
            if self.disaggregation_mode == DisaggregationMode.DECODE:
                self.tree_cache.cache_finished_req(req)

            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and (
                recv_req.abort_all or req.rid.startswith(recv_req.rid)
            ):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True

    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()

    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface."""

        result = self.tp_worker.load_lora_adapter(recv_req)
        return result

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter."""

        result = self.tp_worker.unload_lora_adapter(recv_req)
        return result

    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
        self.send_to_detokenizer.send_pyobj(recv_req)
        return recv_req

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
        """Init the seed and client instance communication group."""
        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
            recv_req
        )
        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)

    def send_weights_to_remote_instance(
        self, recv_req: SendWeightsToRemoteInstanceReqInput
    ):
        """Send the seed instance weights to the destination instance."""
        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
        return SendWeightsToRemoteInstanceReqOutput(success, message)

    def slow_down(self, recv_req: SlowDownReqInput):
        t = recv_req.forward_sleep_time
        if t is not None and t <= 0:
            t = None
        self.forward_sleep_time = t
        return SlowDownReqOutput()

    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        action = recv_req.action
        if action == ExpertDistributionReqType.START_RECORD:
            get_global_expert_distribution_recorder().start_record()
        elif action == ExpertDistributionReqType.STOP_RECORD:
            get_global_expert_distribution_recorder().stop_record()
        elif action == ExpertDistributionReqType.DUMP_RECORD:
            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
        return ExpertDistributionReqOutput()

    def open_session(self, recv_req: OpenSessionReqInput):
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exists, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        elif session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
            return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        # handle error
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]

    def get_print_prefix(self):
        prefix = ""
        if self.attn_dp_rank is not None:
            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
            prefix += f" PP{self.pp_rank}"
        return prefix

    def current_scheduler_metrics_enabled(self):
        return self.attn_tp_rank == 0 or self.enable_metrics_for_all_schedulers

    def maybe_sleep_on_idle(self):
        if self.idle_sleeper is not None:
            self.idle_sleeper.maybe_sleep()

    def handle_freeze_gc(self, recv_req: FreezeGCReq):
        """Handle freeze_gc request: freeze scheduler's GC and forward to detokenizer."""
        freeze_gc("Scheduler")
        self.send_to_detokenizer.send_pyobj(recv_req)
        return None


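# Illustrative usage sketch for IdleSleeper (the socket names are hypothetical; the
# scheduler registers its own ZMQ receive sockets):
#
#     sleeper = IdleSleeper([recv_from_tokenizer, recv_from_rpc])
#     while True:
#         ...  # process any pending work
#         sleeper.maybe_sleep()  # blocks up to 1 s unless a socket becomes readable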
class IdleSleeper:
    """
    In setups with long inactivity periods, it is desirable to reduce system
    power consumption when sglang does nothing. This leads not only to power
    savings but also to more CPU thermal headroom when a request eventually
    comes. This is important when multiple GPUs are connected, as each GPU
    would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    def __init__(self, sockets):
        self.poller = zmq.Poller()
        self.last_empty_time = time.time()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        self.poller.poll(1000)
        if (
            global_config.torch_empty_cache_interval > 0
            and time.time() - self.last_empty_time
            > global_config.torch_empty_cache_interval
        ):
            self.last_empty_time = time.time()
            torch.cuda.empty_cache()


def is_health_check_generate_req(recv_req):
    rid = getattr(recv_req, "rid", None)
    return rid is not None and rid.startswith("HEALTH_CHECK")


def is_work_request(recv_req):
    return isinstance(
        recv_req,
        (
            TokenizedGenerateReqInput,
            TokenizedEmbeddingReqInput,
            BatchTokenizedGenerateReqInput,
            BatchTokenizedEmbeddingReqInput,
        ),
    )


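# NOTE: `run_scheduler_process` is the per-rank subprocess entry point. It configures
# logging, CPU affinity, and tracing, builds a Scheduler, reports readiness and token
# limits back through `pipe_writer`, and then blocks in the event-loop variant chosen
# by the disaggregation mode (null / prefill / decode), pipeline parallelism, and
# overlap scheduling. Any uncaught exception is logged and escalated to the parent
# process via SIGQUIT.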
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    moe_ep_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    # Generate the logger prefix
    prefix = ""
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        # [For Router] if the env var "SGLANG_DP_RANK" exists, set dp_rank to its value
        dp_rank = int(os.environ["SGLANG_DP_RANK"])
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.ep_size > 1:
        prefix += f" EP{moe_ep_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Configure the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set CPU affinity for this GPU process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
    if (numa_node := server_args.numa_node) is not None:
        numa_bind_to_node(numa_node[gpu_id])

    # Set up tracing
    if server_args.enable_trace:
        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
        if server_args.disaggregation_mode == "null":
            thread_label = "Scheduler"
            trace_set_thread_info(thread_label, tp_rank, dp_rank)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
            server_args,
            port_args,
            gpu_id,
            tp_rank,
            moe_ep_rank,
            pp_rank,
            dp_rank,
        )
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                if server_args.pp_size > 1:
                    scheduler.event_loop_pp_disagg_prefill()
                else:
                    scheduler.event_loop_normal_disagg_prefill()

        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)