scheduler.py 97.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
53
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
54
    ReqToMetadataIdxAllocator,
55
    TransferBackend,
56
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
57
)
58
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
59
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
67
68
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
69
    CloseSessionReqInput,
70
    ExpertDistributionReq,
71
    ExpertDistributionReqOutput,
72
73
    FlushCacheReqInput,
    FlushCacheReqOutput,
74
75
    GetInternalStateReq,
    GetInternalStateReqOutput,
76
    GetWeightsByNameReqInput,
77
    HealthCheckOutput,
78
    InitWeightsUpdateGroupReqInput,
79
80
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
81
82
    OpenSessionReqInput,
    OpenSessionReqOutput,
83
    ProfileReq,
84
85
    ReleaseMemoryOccupationReqInput,
    ResumeMemoryOccupationReqInput,
86
87
    RpcReqInput,
    RpcReqOutput,
88
89
    SetInternalStateReq,
    SetInternalStateReqOutput,
90
91
    SlowDownReqInput,
    SlowDownReqOutput,
92
93
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
94
95
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
96
    UpdateWeightFromDiskReqInput,
97
    UpdateWeightsFromDistributedReqInput,
98
    UpdateWeightsFromTensorReqInput,
99
)
100
from sglang.srt.managers.mm_utils import init_embedding_cache
101
102
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
103
    MultimodalInputs,
104
105
    Req,
    ScheduleBatch,
106
    global_server_args_dict,
107
)
108
109
110
111
112
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
113
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
114
115
116
117
from sglang.srt.managers.scheduler_metrics_mixin import (
    RECORD_STEP_TIME,
    SchedulerMetricsMixin,
)
118
119
120
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
121
122
123
124
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
from sglang.srt.managers.scheduler_update_weights_mixin import (
    SchedulerUpdateWeightsMixin,
)
125
from sglang.srt.managers.session_controller import Session
126
from sglang.srt.managers.tp_worker import TpModelWorker
127
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
128
from sglang.srt.managers.utils import validate_input_length
tarinkk's avatar
tarinkk committed
129
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
130
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
131
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
132
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
Lianmin Zheng's avatar
Lianmin Zheng committed
133
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
134
from sglang.srt.reasoning_parser import ReasoningParser
135
from sglang.srt.server_args import PortArgs, ServerArgs
136
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
137
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
138
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
139
from sglang.srt.utils import (
140
    DeepEPMode,
141
    DynamicGradMode,
142
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
143
    configure_gc_logger,
144
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
145
    disable_request_logging,
146
    get_available_gpu_memory,
147
    get_bool_env_var,
148
    get_zmq_socket,
149
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
150
    kill_itself_when_parent_died,
151
    point_to_point_pyobj,
152
    pyspy_dump_schedulers,
153
154
    require_mlp_sync,
    require_mlp_tp_gather,
155
    set_gpu_proc_affinity,
156
157
158
    set_random_seed,
    suppress_other_loggers,
)
159
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
160
161
162

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
# (boolean read from the SGLANG_TEST_RETRACT environment variable).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Grammar timeout read from SGLANG_GRAMMAR_TIMEOUT; defaults to 300.
# NOTE(review): the unit is presumably seconds -- confirm at the use site.
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

# Evaluated once at import time.
_is_cpu = is_cpu()

169

170
171
@dataclass
class GenerationBatchResult:
    """Result record for one generation batch.

    The fields are populated by code outside this definition; their order is
    part of the positional constructor interface and must not change.
    """

    # Logits-processor output for the batch, or None when not available.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors used with pipeline parallelism, or None.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Next token ids for the batch, or None.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request start offsets for extend logprobs.
    extend_logprob_start_len_per_req: List[int]
    # Integer id of the batch (presumably ScheduleBatch's id -- confirm).
    bid: int
    # Whether the batch was able to run with a CUDA graph.
    can_run_cuda_graph: bool
179
180
181
182
183
184
185
186


@dataclass
class EmbeddingBatchResult:
    """Result record for one embedding batch."""

    # Embedding tensor produced for the batch.
    embeddings: torch.Tensor
    # Integer id of the batch (presumably ScheduleBatch's id -- confirm).
    bid: int


Byron Hsu's avatar
Byron Hsu committed
187
188
class Scheduler(
    SchedulerOutputProcessorMixin,
189
190
191
    SchedulerUpdateWeightsMixin,
    SchedulerProfilerMixin,
    SchedulerMetricsMixin,
Byron Hsu's avatar
Byron Hsu committed
192
193
194
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
195
196
197
198
199
200
201
202
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler state for one parallel rank.

        Initialization order matters: the tokenizer/model config must exist
        before the worker is created, the worker must exist before the memory
        pool and cache are fetched from it, and the cache must exist before
        the schedule policy and the disaggregation queues are built.

        Args:
            server_args: Server configuration; most scheduler flags are
                copied from it.
            port_args: IPC endpoint names (scheduler input, rpc, tokenizer,
                detokenizer, metrics) and the NCCL port.
            gpu_id: GPU index forwarded to the workers and used for the
                available-memory log line.
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or None.

        Raises:
            AssertionError: If ``server_args.schedule_conservativeness`` is
                negative.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size

        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        # Only the first pipeline stage's attention-TP rank 0 owns real sockets;
        # every other rank gets no receivers and no-op senders.
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )

            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            # Only the first token id of the encoded think-end marker is kept.
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            # Imported lazily so the module is only loaded when EAGLE is used.
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        # (the last three values of the tuple are intentionally discarded).
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget across pipeline stages.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid memory pool
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # -1 (any non-positive value) means disable. The None guard keeps this
        # consistent with init_memory_pool_and_cache, which treats None as a
        # legal value; comparing None with 0 would raise TypeError.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # Both ratios are capped at 1.0.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement that moves the ratio from init to min.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        # Assign parent_process before the thread starts so the attribute
        # already exists for anything that runs on the watchdog thread.
        # NOTE(review): watchdog_thread is defined outside this view; it is
        # presumed to read self.parent_process -- confirm.
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.init_profier()

        # Input blocker is only created when SGLANG_ENABLE_COLOCATED_BATCH_GEN
        # is set; it is a no-op on every rank except attention-TP rank 0.
        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

        # Init request dispatcher: maps each request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

520
521
    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``. When ``skip_tokenizer_init`` is set both
        ``tokenizer`` and ``processor`` are None. For multimodal models the
        tokenizer is derived from the processor; otherwise only a tokenizer is
        loaded and ``processor`` is None.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
                self.tokenizer = get_tokenizer_from_processor(self.processor)
            else:
                # Keep the attribute defined on every path so later reads of
                # self.processor cannot raise AttributeError.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the worker and build the prefix cache.

        Sets ``self.req_to_token_pool``, ``self.token_to_kv_pool_allocator``,
        ``self.tree_cache`` and ``self.decode_mem_cache_buf_multiplier``, then
        initializes the embedding cache.

        Cache selection, in priority order:
          1. chunk cache (SWA variant when hybrid) if chunked prefill is
             configured and the radix cache is disabled;
          2. ``HiRadixCache`` if hierarchical cache is enabled;
          3. ``SWARadixCache`` if the model is hybrid;
          4. ``RadixCache`` otherwise.

        Raises:
            AssertionError: If hybrid mode is combined with a disaggregation
                mode other than ``"null"``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        # NOTE(review): __init__ treats chunked_prefill_size <= 0 as "disabled",
        # but this check only tests for None, so -1 still selects the chunk
        # cache here -- confirm this is intended.
        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            if self.is_hybrid:
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
                hicache_io_backend=(
                    "direct"
                    if server_args.attention_backend
                    == "fa3"  # hot fix for incompatibility
                    else server_args.hicache_io_backend
                ),
                hicache_storage_backend=server_args.hicache_storage_backend,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
            )
        elif self.is_hybrid:
            assert (
                self.server_args.disaggregation_mode == "null"
            ), "Hybrid mode does not support disaggregation yet"
            self.tree_cache = SWARadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                sliding_window_size=self.sliding_window_size,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        # 1 without speculative decoding; otherwise
        # num_draft_tokens + eagle_topk * num_steps.
        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

        # Cache size comes from SGLANG_VLM_CACHE_SIZE_MB (default 100) and is
        # converted from MB to bytes.
        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

Byron Hsu's avatar
Byron Hsu committed
627
    def init_disaggregation(self):
        """Create the transfer backend and the mode-specific disaggregation queues.

        Always sets ``self.transfer_backend``. In DECODE mode it builds the
        metadata index allocator, the metadata buffers, the transfer queue and
        the pre-allocation queue. In PREFILL mode it builds the allocator, the
        buffers, the bootstrap queue and an empty in-flight list. For any other
        mode nothing further is created.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                # The draft KV pool only exists when a draft worker was created.
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                # The draft KV pool only exists when a draft worker was created.
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
719

720
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the basic (non-overlapped) scheduler loop; never returns.

        Every iteration receives and processes the pending input requests,
        selects the next batch, and then either runs that batch and processes
        its result or, when there is nothing to run, performs the idle
        self-check.
        """
        while True:
            # Pull in whatever arrived since the last iteration and dispatch it.
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()
            else:
                outcome = self.run_batch(next_batch)
                self.process_batch_result(next_batch, outcome)

            # Remember what ran in this iteration for the next scheduling step.
            self.last_batch = next_batch
738

739
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU computation.

        The current batch is launched first; the result of the previously
        launched batch is processed afterwards. Launched batches wait in
        ``self.result_queue`` until their results are processed one iteration
        later. Never returns.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                # Queue a copy of the batch together with its result.
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    dummy_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                else:
                    prev_batch.next_batch_sampling_info = None
                # NOTE: we should use the currently launched batch's launch_done
                # event instead of the last batch's.
                launch_done = batch.launch_done if batch else None
                self.process_batch_result(prev_batch, prev_result, launch_done)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

            self.last_batch = batch

782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop keeps ``self.pp_size`` microbatch slots. For each slot it
        swaps in that slot's running/last batch, receives and processes
        requests, and schedules and runs a batch. The last pipeline rank then
        sends the sampled outputs onward; every rank receives the outputs for
        the *next* slot and post-processes that microbatch; non-last ranks
        forward the previously received outputs, the received requests and the
        hidden-state proxy tensors to the next stage. Never returns.
        """
        # Per-slot state: the batch scheduled in the current round, the batch
        # whose result was processed last, and the batch id of the last run.
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap in the scheduler state that belongs to this slot.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                # Save the (possibly replaced) running batch back into the slot.
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            # Flatten the logits output attributes into the
                            # tensor dict under a "logits_output." key prefix.
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Strip the "logits_output." prefix to rebuild the keyword
                    # arguments of LogitsProcessorOutput.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    # NOTE(review): `result` below is whatever the most recent
                    # run_batch call returned, which is not necessarily the
                    # result of mbs[next_mb_id] -- confirm that taking
                    # can_run_cuda_graph from it is intended.
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # The outputs received in this step are forwarded in the next one.
                pp_outputs = next_pp_outputs

            # When no microbatch slot had work in this round, do self-check and
            # re-init some states.
            if server_is_idle:
                self.self_check_during_idle()
914

915
916
    def recv_requests(self) -> List[Req]:
        """Receive new requests on attention-TP rank 0 and share them with the other TP ranks.

        On pipeline rank 0, attention-TP rank 0 drains the tokenizer socket and
        then the RPC socket without blocking. On the other pipeline ranks,
        attention-TP rank 0 obtains the requests through
        ``point_to_point_pyobj`` (presumably from the previous pipeline stage).
        All other ranks start with ``None`` and get the requests through
        ``broadcast_pyobj``.

        With DP attention enabled, work requests (generate / embedding) are
        broadcast inside the attention-TP group while every other request is
        broadcast inside the whole TP group.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        elif self.pp_rank == 0:
            recv_reqs = []
            # Drain each socket until a non-blocking receive raises ZMQError.
            for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                while True:
                    try:
                        received = sock.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(received)
        else:
            dp_offset = self.attn_dp_rank * self.attn_tp_size
            own_rank = self.pp_rank * self.tp_size + dp_offset
            recv_reqs = point_to_point_pyobj(
                [],
                own_rank,
                self.world_group.device_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                own_rank,
            )

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                # Split the requests into work requests and control requests.
                work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                work_reqs = [r for r in recv_reqs if isinstance(r, work_types)]
                control_reqs = [r for r in recv_reqs if not isinstance(r, work_types)]
            else:
                work_reqs = None
                control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
996
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request and route the reply it produces, if any.

        Health-check generate requests are only counted (not dispatched) while
        other work is in flight. Work requests are rejected with an abort reply
        once the waiting queue has reached ``self.max_queued_requests``; a
        limit of ``None`` means the queue size is not limited.

        Args:
            recv_reqs: The requests to process, as returned by ``recv_requests``.
        """
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None or not self.running_batch.is_empty()
            ):
                self.return_health_check_ct += 1
                continue

            # If it is a work request, accept or reject the request based on the request queue size.
            # An unset (None) limit must not be compared against an int.
            if (
                is_work_request(recv_req)
                and self.max_queued_requests is not None
                and len(self.waiting_queue) + 1 > self.max_queued_requests
            ):
                abort_req = AbortReq(
                    recv_req.rid,
                    finished_reason={
                        "type": "abort",
                        "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                        "message": "The request queue is full.",
                    },
                )
                self.send_to_tokenizer.send_pyobj(abort_req)
                continue

            output = self._request_dispatcher(recv_req)
            if output is not None:
                if isinstance(output, RpcReqOutput):
                    # RPC replies go back over the RPC socket when there is one.
                    if self.recv_from_rpc is not None:
                        self.recv_from_rpc.send_pyobj(output)
                else:
                    self.send_to_tokenizer.send_pyobj(output)
1025
1026
1027
1028
1029

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        The request is created from scratch, or from an existing session when
        ``recv_req.session_params.id`` names one. It is then validated:
        multimodal expansion, prompt length and logprob start position. A
        request that fails one of these checks is marked with
        ``set_finish_with_abort`` and still passed to
        ``_add_request_to_queue``. A disaggregated request without a bootstrap
        room is aborted and streamed out immediately instead. A request whose
        grammar is not cached yet is appended to ``self.grammar_queue``.

        Args:
            recv_req: The tokenized generate request to handle.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        "Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            # A session id was given but is unknown to this scheduler.
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens so that prompt plus output fits into max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        sampling_params = req.sampling_params
        # Every constraint is tested with `is not None`, so a key is always
        # bound whenever the grammar path is taken.
        if sampling_params.json_schema is not None:
            key = ("json", sampling_params.json_schema)
        elif sampling_params.regex is not None:
            key = ("regex", sampling_params.regex)
        elif sampling_params.ebnf is not None:
            key = ("ebnf", sampling_params.ebnf)
        elif sampling_params.structural_tag is not None:
            key = ("structural_tag", sampling_params.structural_tag)
        else:
            key = None

        if key is not None:
            assert self.grammar_backend is not None
            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            elif value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp the queueing start time on ``req`` and put it into the queue for the current mode.

        Prefill mode uses the prefill bootstrap queue, decode mode uses the
        decode prealloc queue, and any other mode uses the waiting queue. The
        KV-cache prefetch is triggered for every mode except decode.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            self._prefetch_kvcache(req)
            num_kv_heads = self.model_config.num_key_value_heads
            self.disagg_prefill_bootstrap_queue.add(req, num_kv_heads)
            return

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        self._prefetch_kvcache(req)
        self.waiting_queue.append(req)

1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
    def _prefetch_kvcache(self, req: Req):
        """Ask the tree cache to prefetch the unmatched input tokens of ``req`` from storage.

        Does nothing unless ``self.enable_hicache_storage`` is set.
        """
        if not self.enable_hicache_storage:
            return

        req.init_next_round_input(self.tree_cache)
        last_hash = req.last_host_node.get_last_hash_value()
        matched_len = len(req.prefix_indices) + req.host_hit_length
        # todo, free-form fetching, calculating hash keys on the fly
        # Prefetch when nothing matched at all, or when something matched and
        # the last matched host node provides a hash value.
        should_prefetch = matched_len == 0 or (
            matched_len > 0 and last_hash is not None
        )
        if not should_prefetch:
            return

        unmatched_tokens = req.fill_ids[matched_len:]
        self.tree_cache.prefetch_from_storage(
            req.rid, req.last_host_node, unmatched_tokens, last_hash
        )

1210
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Add several requests at once to the queue that matches the current mode.

        Args:
            reqs: The requests to enqueue.
            is_retracted: Passed on to the decode prealloc queue; it is not
                used in the other modes.
        """
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            num_kv_heads = self.model_config.num_key_value_heads
            self.disagg_prefill_bootstrap_queue.extend(reqs, num_kv_heads)
            return

        if mode == DisaggregationMode.DECODE:
            # If this is a decode server, we put the request to the decode pending prealloc queue
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return

        self.waiting_queue.extend(reqs)
1220
1221
1222

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        A request that fails validation (multimodal expansion or prompt
        length) is marked with ``set_finish_with_abort`` before it is queued,
        the same way ``handle_generate_request`` treats its failures.

        Args:
            recv_req: The tokenized embedding request to handle.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Mark the request as aborted explicitly so an over-long prompt is
            # never queued as regular work.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1266

1267
1268
1269
1270
1271
    def self_check_during_idle(self):
        """Run the idle-time checks, reset the new-token ratio, then maybe sleep."""
        # Consistency checks come first, in this order.
        for run_check in (self.check_memory, self.check_tree_cache):
            run_check()

        # Go back to the initial new-token ratio.
        self.new_token_ratio = self.init_new_token_ratio
        self.maybe_sleep_on_idle()
1272

Lianmin Zheng's avatar
Lianmin Zheng committed
1273
    def check_memory(self):
        """Check the KV token pool and the request slot pool for leaks.

        Also logs scheduler metrics when metrics are enabled and more than 30
        seconds have passed since the last log, and publishes KV events.

        Raises:
            ValueError: If the token pool accounting or the number of free
                request slots does not match the expected total.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            # In the hybrid case any used token in either pool counts as a leak.
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # With the hierarchical cache enabled, protected tokens are
            # excluded from the expected total.
            memory_leak = (available_size + evictable_size) != (
                self.max_total_num_tokens
                if not self.enable_hierarchical_cache
                else self.max_total_num_tokens - protected_size
            )
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        # In decode mode the pre-allocated slots are part of the expected total.
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        if len(self.req_to_token_pool.free_slots) != req_total_size:
            # Report the total that was actually compared against.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                # Report the larger of the two pools' figures.
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1349

Hanming Lu's avatar
Hanming Lu committed
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
    def check_tree_cache(self):
        """Run the tree cache's sanity check when a hybrid SWA radix cache is in use.

        Does nothing unless `self.is_hybrid` is truthy and `self.tree_cache`
        is a `SWARadixCache` instance.
        """
        if not self.is_hybrid:
            return
        if not isinstance(self.tree_cache, SWARadixCache):
            return
        self.tree_cache.sanity_check()

    def _get_token_info(self):
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

1385
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        First folds the previous extend batch (if any) into `running_batch`,
        then prefers a new prefill batch from `get_new_batch_prefill()`; when
        there is none, the running batch is updated and used for decoding.
        When `require_mlp_sync(self.server_args)` is truthy the chosen batch is
        passed through `prepare_mlp_sync_batch`.

        Returns:
            The batch to run, or None when there is nothing to run.
        """
        # Merge the prefill batch into the running batch
        excluded_chunked_reqs = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that only finished
            # requests get merged into running_batch.
            excluded_chunked_reqs.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the current
                # microbatch still tracks an outdated chunked_req; discard it.
                excluded_chunked_reqs.add(self.last_batch.chunked_req)

            # Filter batch
            bs_before_filter = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(excluded_chunked_reqs)
            )
            if self.last_batch.batch_size() < bs_before_filter:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()

        need_dp_attn_preparation = require_mlp_sync(self.server_args)
        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
            # In speculative decoding, prefill batches and decode batches cannot
            # be processed in the same DP attention group. Prepare idle batches
            # in advance to skip preparing decode batches when there are prefill
            # batches in the group.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            need_dp_attn_preparation = new_batch is None

        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        elif self.running_batch.is_empty():
            ret = None
        else:
            # Run decode
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = None if self.running_batch.is_empty() else self.running_batch

        # Handle DP attention
        if need_dp_attn_preparation:
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
1443

1444
1445
1446
1447
1448
1449
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be added given `running_bs` running ones.

        The base budget is `max_micro_batch_size` (from the global server args)
        minus `running_bs`; when `pp_size > 1` it is additionally capped by the
        request-to-token pool's available size.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size > 1:
            return min(budget, self.req_to_token_pool.available_size())
        return budget

Lianmin Zheng's avatar
Lianmin Zheng committed
1450
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to build a new prefill (extend) batch from the waiting queue.

        Requests are admitted through a `PrefillAdder` until a limit is hit
        (allocatable request slots, LoRA adapters per batch, or the adder
        reporting that it cannot continue). Admitted requests are removed from
        `self.waiting_queue`, and `self.chunked_req` is updated when the adder
        produces a new chunked request.

        Returns:
            The new batch (already passed through `prepare_for_extend()`), or
            None when prefill is not allowed or no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            # Only defined (and only read below) when LoRA is enabled.
            lora_set = {req.lora_path for req in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.enable_lora
                and len(
                    lora_set
                    | {admitted.lora_path for admitted in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc queue and transfer queue can also
                # take memory, so check against the pool's actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            if self.enable_hicache_storage:
                self.tree_cache.check_prefetch_progress(req.rid)

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once instead of once per waiting request.
        admitted_reqs = set(can_run_list)
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in admitted_reqs
        ]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The running requests (if any) now live in new_batch; start over
            # with an empty running batch that keeps the batch_is_full flag.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1603
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Filters the batch, retracts requests when `check_decode_mem` fails (or
        when the `TEST_RETRACT` condition holds), otherwise decays
        `self.new_token_ratio`, and finally prepares the batch for decode.

        Returns:
            The same `batch` object; it may be empty after filtering.
        """
        bs_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)
        else:
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
            num_retracted_reqs = len(retracted_reqs)
            self.new_token_ratio = new_token_ratio

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {num_retracted_reqs}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )

            # Retracted requests go back to the queue to be scheduled again.
            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
            self.total_retracted_reqs += num_retracted_reqs

        if batch.batch_size() < bs_at_entry:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1642

1643
1644
1645
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        For generation models this runs either the TP worker or, when a
        speculative algorithm is configured, the draft worker, and returns a
        `GenerationBatchResult`. Otherwise (embedding or reward model) it
        returns an `EmbeddingBatchResult`.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )

        # Run forward
        if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()

            # update the consumer index of hicache to the running batch
            self.tp_worker.set_hicache_consumer(
                model_worker_batch.hicache_consumer_index
            )
            if self.pp_group.is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = (
                    self.tp_worker.forward_batch_generation(model_worker_batch)
                )
            else:
                # Non-last pipeline ranks produce proxy tensors, not logits.
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                    self.tp_worker.forward_batch_generation(model_worker_batch)
                )
            bid = model_worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        on_last_rank = self.pp_group.is_last_rank
        if on_last_rank:
            batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        extend_input_len_per_req = (
            [req.extend_input_len for req in batch.reqs]
            if batch.return_logprob or self.spec_algorithm.is_eagle()
            else None
        )
        extend_logprob_start_len_per_req = (
            [req.extend_logprob_start_len for req in batch.reqs]
            if batch.return_logprob
            else None
        )

        return GenerationBatchResult(
            logits_output=logits_output if on_last_rank else None,
            pp_hidden_states_proxy_tensors=(
                pp_hidden_states_proxy_tensors if not on_last_rank else None
            ),
            next_token_ids=next_token_ids if on_last_rank else None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1723

1724
1725
1726
1727
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch post-processing of `result` based on the batch's forward mode.

        Decode and extend batches go to their dedicated handlers. Idle batches
        only need work when overlap scheduling is enabled. Afterwards, one
        pending health-check reply (if any) is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return

        # Return some signal for the health check.
        # This is used to prevent the health check signal being blocked by long context prefill.
        # However, one minor issue is that this code path does not check the status of detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1748
1749
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Call `prepare_mlp_sync_batch_raw` with this scheduler's configuration.

        Args:
            local_batch: The local batch to synchronize; forwarded unchanged.

        Returns:
            Whatever `prepare_mlp_sync_batch_raw` returns.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        """Exchange per-rank batch info via all-gather and annotate the local batch.

        Each rank contributes a 6-element int64 vector: number of tokens, a
        CUDA-graph eligibility flag, number of tokens for logprob, an
        "extend in batch" flag, plus the values produced by
        `TboDPAttentionPreparer.prepare_all_gather` (presumably two, given the
        6-wide gather buffer and the `4:6` slice below).

        Returns:
            `local_batch` with the gathered fields set; an idle batch from
            `get_idle_batch()` when `local_batch` was None but some rank has
            tokens; or None when no rank has tokens and `local_batch` was None.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            num_tokens_for_logprob = sum(
                max(extend_len - logprob_start_len, 1)
                for logprob_start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        can_cuda_graph = (
            1
            if local_batch is None or local_batch.forward_mode.is_decode_or_idle()
            else 0
        )

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # With overlap scheduling enabled the gather runs on the CPU group.
        if disable_overlap_schedule:
            group, device = tp_group.device_group, tp_group.device
        else:
            group, device = tp_group.cpu_group, "cpu"

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
            device=device,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=group,
        )

        # Per-DP-rank values are read from the first attention-TP rank.
        per_dp_rank = global_info[:, 0, :]
        global_num_tokens = per_dp_rank[:, 0].tolist()
        can_cuda_graph = min(per_dp_rank[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp_rank[:, 2].tolist()
        is_extend_in_batch = per_dp_rank[:, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Some other rank has work: take part with an idle batch.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is None:
            return local_batch

        # TODO: handle the case when moe_dense_tp_size != 1
        if require_mlp_tp_gather:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
        else:
            local_batch.global_num_tokens = [num_tokens]
            local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
        local_batch.is_extend_in_batch = any(is_extend_in_batch)
        local_batch.tbo_split_seq_index = tbo_split_seq_index
        local_batch.global_forward_mode = global_forward_mode

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
1874
1875
1876
1877
1878

    def get_idle_batch(self):
        """Create a batch with no requests and prepare it for an idle forward pass."""
        no_reqs = []
        idle_batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle_batch.prepare_for_idle()
        return idle_batch

1889
1890
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned from the front. Each unfinished request's grammar
        future is polled briefly; the scan stops at the first one that is not
        ready. With more than one TP rank, the ready and timeout counts are
        max-reduced so every rank moves the same prefix of the queue: a rank
        that is behind blocks until the extra grammars are resolved. Requests
        selected as timed out are cancelled, aborted, and their grammar key is
        cached as invalid. The processed prefix is then handed to
        `_extend_requests_to_queue` and removed from `grammar_queue`.
        """
        # Seconds to wait on each grammar future per scheduling pass.
        poll_interval = 0.03

        def store_resolved_grammar(req):
            # Cache a copy of the resolved grammar; abort the request when the
            # result is the invalid-grammar sentinel.
            self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
            if req.grammar is INVALID_GRAMMAR_OBJ:
                req.set_finish_with_abort(
                    f"Invalid grammar request: {req.grammar_key=}"
                )

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            if req.finished():  # It is aborted by AbortReq
                num_ready_reqs += 1
                continue
            try:
                req.grammar = req.grammar.result(timeout=poll_interval)
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_interval:
                    num_timeout_reqs = 1
                break
            store_resolved_grammar(req)
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Catch up with the fastest rank: block until these are resolved.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                store_resolved_grammar(req)
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests come right after the synced ready prefix. Starting
        # at the local count instead would, on a rank that had to catch up,
        # target a request whose grammar was just resolved above and leave the
        # real timed-out one untouched. Clamp so ranks that disagree on the
        # counts cannot index past the end of the queue.
        timeout_end = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for i in range(num_ready_reqs_max, timeout_end):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = timeout_end

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1954
1955
1956
1957
1958
1959
1960
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Finalize `batch.next_batch_sampling_info` and signal that it is done.

        Does nothing when the batch carries no next-batch sampling info. When
        grammars are attached, the vocab mask is refreshed and the current
        stream is synchronized before `sampling_info_done` is set.
        """
        sampling_info = batch.next_batch_sampling_info
        if not sampling_info:
            return

        if sampling_info.grammars is not None:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        sampling_info.sampling_info_done.set()

1961
1962
1963
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        The loop wakes up every `watchdog_timeout // 2` seconds. While a batch
        is in flight (`cur_batch` is not None) and `forward_ct` has not advanced
        for longer than `watchdog_timeout`, the loop exits, diagnostics are
        logged, and SIGQUIT is sent to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    # No progress since the last check; give up once the timeout elapses.
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    # Progress was made; restart the timeout window.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(self.watchdog_timeout // 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        # Emit a newline on both streams and flush them before the process is signalled.
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2014
2015
2016
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Run `flush_cache` and wrap its boolean result in a `FlushCacheReqOutput`."""
        return FlushCacheReqOutput(success=self.flush_cache())
2017

2018
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when nothing is pending: the waiting queue is
        empty, the running batch is empty and, with pipeline parallelism
        (`pp_size > 1`), every running micro-batch is empty as well.

        Returns:
            True if the caches were flushed, False if the flush was skipped
            because there are pending requests.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        # The draft worker used for speculative decoding keeps its own pools.
        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        # Reset the running statistics together with the caches.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
2055
2056
    def get_load(self):
        """Return the current load of this scheduler, measured in tokens.

        The load is the number of KV-cache tokens that are neither available
        nor evictable, plus the input lengths of the requests still queued
        (waiting queue, and the bootstrap / prealloc queue in disaggregated
        prefill / decode mode). For hybrid models the larger of the full and
        SWA pool loads is used.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        if self.is_hybrid:
            load_full = (
                self.full_tokens_per_layer
                - self.token_to_kv_pool_allocator.full_available_size()
                - self.tree_cache.full_evictable_size()
            )
            load_swa = (
                self.swa_tokens_per_layer
                - self.token_to_kv_pool_allocator.swa_available_size()
                - self.tree_cache.swa_evictable_size()
            )
            load = max(load_full, load_swa)
        else:
            load = (
                self.max_total_num_tokens
                - self.token_to_kv_pool_allocator.available_size()
                - self.tree_cache.evictable_size()
            )
        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            load += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Entries of the prealloc queue wrap the request in a `.req` attribute.
            load += sum(
                len(req.req.origin_input_ids)
                for req in self.disagg_decode_prealloc_queue.queue
            )

        return load

2089
2090
2091
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Collect a snapshot of the scheduler's internal state.

        The result starts as a copy of `global_server_args_dict` and is
        extended with throughput, memory usage, speculative-decoding
        statistics (when available), step timings (when `RECORD_STEP_TIME` is
        set) and the current load.

        Returns:
            A `GetInternalStateReqOutput` wrapping the state dictionary.
        """
        # Copy so that the extra keys below do not leak into the global args.
        ret = dict(global_server_args_dict)
        ret["last_gen_throughput"] = self.last_gen_throughput
        ret["memory_usage"] = {
            "weight": round(
                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
            ),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }

        if not _is_cpu:
            ret["memory_usage"]["cuda_graph"] = round(
                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
            )

        # Guard on the count to avoid dividing by zero before any accept was recorded.
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

        ret["load"] = self.get_load()

        return GetInternalStateReqOutput(internal_state=ret)
2117
2118
2119
2120
2121

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server arguments.

        All requested keys are validated first; if any key is not updatable or
        its value is out of range, nothing is changed. On success the
        speculative-decoding accept statistics are logged (when available) and
        reset before the new values are written to `global_server_args_dict`.

        Returns:
            A `SetInternalStateReqOutput` whose `updated` flag tells whether
            the update was applied, together with the current global args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            if k == "max_micro_batch_size":
                max_allowed = self.max_running_requests // self.pp_size
                if v > max_allowed or v < 1:
                    logger.warning(
                        f"Updating {k} to {v} is rejected because it is out of the valid range [1, {max_allowed}]."
                    )
                    if_success = False
                    break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")
        # Report the real outcome; a rejected update must not be shown as applied.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Dispatch an RPC request to the scheduler method named in `recv_req.method`.

        Any exception raised by the target method is caught, logged and
        reported in the returned `RpcReqOutput`. `barrier()` is called whether
        or not the call succeeded.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        failure = None
        try:
            getattr(self, recv_req.method)(recv_req.parameters)
        except Exception as e:
            failure = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")
        succeeded = failure is None

        barrier()
        return RpcReqOutput(succeeded, str(failure) if failure else "")

2175
2176
    def abort_request(self, recv_req: AbortReq):
        """Abort every request selected by `recv_req`, wherever it currently lives.

        A request is selected when `recv_req.abort_all` is set or its rid
        starts with `recv_req.rid`. Three abort methods are used depending on
        how far the request has progressed; see the inline comments.
        """

        def should_abort(rid) -> bool:
            # Prefix match, so one AbortReq can cover several derived rids.
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [
            i for i, req in enumerate(self.waiting_queue) if should_abort(req.rid)
        ]

        # Pop in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if should_abort(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if should_abort(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if should_abort(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if should_abort(decode_req.req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if should_abort(decode_req.req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and should_abort(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2248

2249
2250
2251
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Placeholder for pausing the engine; always raises `NotImplementedError`."""
        raise NotImplementedError

2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        The request is forwarded unchanged to the TP worker and the worker's
        result is returned as-is.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter.

        The request is forwarded unchanged to the TP worker and the worker's
        result is returned as-is.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2268
2269
2270
2271
2272
2273
2274
    def slow_down(self, recv_req: SlowDownReqInput):
        """Set `forward_sleep_time` from the request.

        A missing or non-positive value stores `None`; any other value is
        stored unchanged.
        """
        sleep_time = recv_req.forward_sleep_time
        disabled = sleep_time is None or sleep_time <= 0
        self.forward_sleep_time = None if disabled else sleep_time
        return SlowDownReqOutput()

2275
2276
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the global expert distribution recorder.

        Raises:
            ValueError: if `recv_req` is not one of the recognized
                `ExpertDistributionReq` values.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            get_global_expert_distribution_recorder().start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            get_global_expert_distribution_recorder().stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError(f"Unrecognized ExpertDistributionReq value: {recv_req=}")
        return ExpertDistributionReqOutput()
2285

2286
    def open_session(self, recv_req: OpenSessionReqInput):
        """Open a new session under `recv_req.session_id`.

        Returns:
            An `OpenSessionReqOutput` carrying the session id and a success
            flag. The flag is False when the id is already in use or is None.
        """
        session_id = recv_req.session_id

        # handle error
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2300
2301
2302
2303
2304
2305
2306
2307
2308

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by `recv_req.session_id`; warn if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        # handle error
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2309
2310
    def get_print_prefix(self):
        """Build the rank prefix string for this scheduler.

        Each component is preceded by a space: ` DP<rank>` when
        `attn_dp_rank` is set, ` TP<rank>` when the server runs with
        `tp_size > 1`, and ` PP<rank>` when `pp_size > 1`. The result is an
        empty string when none of them apply.
        """
        prefix = ""
        if self.attn_dp_rank is not None:
            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
            prefix += f" PP{self.pp_rank}"
        return prefix

2319
2320
    def current_scheduler_metrics_enabled(self):
        """Tell whether this scheduler should report metrics.

        Attention-TP rank 0 always reports; other ranks report only when
        `enable_metrics_for_all_schedulers` is set.
        """
        if self.attn_tp_rank == 0:
            return True
        return self.enable_metrics_for_all_schedulers
2321

2322
2323
2324
    def maybe_sleep_on_idle(self):
        """Delegate to the idle sleeper, if one is configured."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()
2325

2326

2327
2328
2329
2330
2331
2332
2333
class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This would lead not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important in cases when multiple GPUs are connected
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.
    """

    # Upper bound, in milliseconds, on how long one `maybe_sleep` call blocks.
    POLL_TIMEOUT_MS = 1000

    def __init__(self, sockets):
        """Register every socket in `sockets` for incoming-data polling."""
        self.poller = zmq.Poller()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a registered socket has data or the poll timeout expires."""
        self.poller.poll(self.POLL_TIMEOUT_MS)
2346

2347

2348
2349
def is_health_check_generate_req(recv_req):
    """Return True if `recv_req` carries a rid that starts with "HEALTH_CHECK".

    Requests without a `rid` attribute, or whose `rid` is not a string
    (for example `None`), are never health checks.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")
2350

2351
2352
2353

def is_work_request(recv_req):
    """Return True if `recv_req` is a tokenized generate or embedding request."""
    work_request_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_request_types)
2354
2355


2356
2357
2358
2359
2360
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Configures the process (title, fault handler, logging, optional CPU
    affinity), builds a `Scheduler`, reports readiness through `pipe_writer`
    and then runs the event loop that matches the disaggregation mode. Any
    exception is logged and SIGQUIT is sent to the parent process.

    Args:
        server_args: Server configuration.
        port_args: Port configuration passed through to the scheduler.
        gpu_id: GPU id passed to the scheduler and to the affinity helper.
        tp_rank: Tensor-parallel rank of this process.
        pp_rank: Pipeline-parallel rank of this process.
        dp_rank: Data-parallel rank, or None. When None, the
            `SGLANG_DP_RANK` environment variable is used if present.
        pipe_writer: Pipe end on which the "ready" message is sent.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # This must run before the prefix is built so that the process title and
    # the logger prefix include the DP rank in that case as well.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()

        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)