scheduler.py 121 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import datetime
17
import faulthandler
18
import logging
19
import os
20
import signal
21
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
22
import threading
23
import time
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from pathlib import Path
29
from types import SimpleNamespace
30
from typing import Dict, List, Optional, Tuple, Union
31

32
import psutil
33
import setproctitle
34
import torch
35
import zmq
36
from torch.distributed import barrier
37

38
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
39
from sglang.srt.configs.model_config import ModelConfig
40
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
41
42
43
44
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
45
46
47
48
49
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
50
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
51
52
53
54
55
56
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
57
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
58
    ReqToMetadataIdxAllocator,
59
    TransferBackend,
60
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
61
)
62
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
63
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
64
65
66
67
68
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
69
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
70
71
72
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
73
    CloseSessionReqInput,
74
    ExpertDistributionReq,
75
    ExpertDistributionReqOutput,
76
77
    FlushCacheReqInput,
    FlushCacheReqOutput,
78
79
    GetInternalStateReq,
    GetInternalStateReqOutput,
80
81
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
82
    HealthCheckOutput,
83
84
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
85
86
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
87
88
    OpenSessionReqInput,
    OpenSessionReqOutput,
89
    ProfileReq,
90
91
    ProfileReqOutput,
    ProfileReqType,
92
93
94
95
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
96
97
    RpcReqInput,
    RpcReqOutput,
98
99
    SetInternalStateReq,
    SetInternalStateReqOutput,
100
101
    SlowDownReqInput,
    SlowDownReqOutput,
102
103
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
104
105
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
106
107
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
108
109
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
110
111
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
112
)
113
from sglang.srt.managers.mm_utils import init_embedding_cache
114
115
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
116
    MultimodalInputs,
117
118
    Req,
    ScheduleBatch,
119
    global_server_args_dict,
120
)
121
122
123
124
125
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
fzyzcjy's avatar
fzyzcjy committed
126
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
127
128
129
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
130
from sglang.srt.managers.session_controller import Session
131
from sglang.srt.managers.tp_worker import TpModelWorker
132
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
133
from sglang.srt.managers.utils import validate_input_length
tarinkk's avatar
tarinkk committed
134
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
135
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
136
from sglang.srt.mem_cache.radix_cache import RadixCache
Hanming Lu's avatar
Hanming Lu committed
137
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
138
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
139
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
140
from sglang.srt.reasoning_parser import ReasoningParser
141
from sglang.srt.server_args import PortArgs, ServerArgs
142
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
143
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
144
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
145
from sglang.srt.utils import (
146
    DeepEPMode,
147
    DynamicGradMode,
148
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
149
    configure_gc_logger,
150
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
151
    disable_request_logging,
152
    get_available_gpu_memory,
153
    get_bool_env_var,
154
    get_zmq_socket,
155
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
156
    kill_itself_when_parent_died,
157
    point_to_point_pyobj,
158
    pyspy_dump_schedulers,
159
160
    require_mlp_sync,
    require_mlp_tp_gather,
161
    set_gpu_proc_affinity,
162
163
164
    set_random_seed,
    suppress_other_loggers,
)
165
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
166
167
168

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Debug flag read from SGLANG_RECORD_STEP_TIME; consumers are outside this view.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar timeout from SGLANG_GRAMMAR_TIMEOUT (default 300; presumably seconds — confirm).
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

_is_cpu = is_cpu()


@dataclass
class GenerationBatchResult:
    """Result of running one generation batch through the model worker."""

    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors, presumably for pipeline-parallel hand-off — confirm.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    next_token_ids: Optional[List[int]]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Batch identifier.
    bid: int
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result of running one embedding (non-generation) batch."""

    # Embedding output tensor; layout presumably one row per request — confirm.
    embeddings: torch.Tensor
    # Batch identifier.
    bid: int


class KvMetrics:
    """Bag of KV-cache and request-slot metrics.

    Every field starts out as ``None``; nothing in this class fills them in.
    """

    def __init__(self):
        # Request-slot occupancy.
        self.request_active_slots = self.request_total_slots = None
        # KV-cache block occupancy.
        self.kv_active_blocks = self.kv_total_blocks = None
        self.num_requests_waiting = None
        # Cache usage / prefix-hit ratios.
        self.gpu_cache_usage_perc = self.gpu_prefix_cache_hit_rate = None
        self.data_parallel_rank = None


class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This would lead not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important in cases when multiple GPUs are connected
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.

    Args:
        sockets: ZMQ sockets whose incoming data should wake the scheduler.
        poll_timeout_ms: Upper bound, in milliseconds, on how long a single
            ``maybe_sleep`` call blocks when no socket becomes readable.
            Defaults to 1000, the previously hard-coded value.
    """

    def __init__(self, sockets, poll_timeout_ms: int = 1000):
        self.poller = zmq.Poller()
        self.poll_timeout_ms = poll_timeout_ms
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a registered socket is readable or the timeout elapses."""
        self.poller.poll(self.poll_timeout_ms)


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
232
233
234
235
236
237
238
239
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up IPC, model workers, caches, scheduling policy and bookkeeping.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names used to create the ZMQ sockets.
            gpu_id: Device index handed to the model (and draft) workers.
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or ``None`` (presumably when data
                parallelism is not in use — confirm).
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.enable_lora = server_args.enable_lora
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_metrics_for_all_schedulers = (
            server_args.enable_metrics_for_all_schedulers
        )
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.enable_hicache_storage = server_args.hicache_storage_backend is not None
        self.page_size = server_args.page_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        # Only the first pipeline stage's attention-TP rank 0 owns real sockets;
        # every other rank gets no-op senders (see the else branch).
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        if self.current_scheduler_metrics_enabled():
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_queued_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget evenly over PP stages.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Hybrid
        self.is_hybrid = self.tp_worker.is_hybrid
        if self.is_hybrid:
            self.sliding_window_size = self.tp_worker.sliding_window_size
            self.full_tokens_per_layer, self.swa_tokens_per_layer = (
                self.tp_worker.get_tokens_per_layer_info()
            )

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # -1 (any non-positive value) means disabled; tolerate an unset (None)
        # value too, which init_memory_pool_and_cache also treats as possible.
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args,
                self.tokenizer,
                self.model_config.vocab_size,
                self.model_config.hf_eos_token_id,
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decay that takes the ratio from init to min over the
        # configured number of decay steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.init_profier()

        self.input_blocker = (
            SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
            if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
            else None
        )

        # Init metrics stats
        self.init_metrics(tp_rank, pp_rank, dp_rank)
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

    def current_scheduler_metrics_enabled(self):
        """Tell whether this scheduler rank should publish metrics.

        Attention-TP rank 0 always does; any other rank returns the value of
        ``enable_metrics_for_all_schedulers``.
        """
        if self.attn_tp_rank == 0:
            return True
        return self.enable_metrics_for_all_schedulers

    def maybe_sleep_on_idle(self):
        """Delegate to the idle sleeper, if one was created in ``__init__``."""
        sleeper = self.idle_sleeper
        if sleeper is None:
            return
        sleeper.maybe_sleep()

    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Sets ``model_config``, ``is_generation``, ``tokenizer`` and
        ``processor``. Both of the latter are ``None`` when
        ``skip_tokenizer_init`` is set. For text-only models ``processor`` is
        ``None`` and only a tokenizer is loaded.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            # Multimodal processors wrap a tokenizer; reuse it.
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Keep the attribute defined on every path (previously left unset here).
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Chooses ``self.tree_cache`` in this order:
          * chunk cache (SWA variant for hybrid models) when chunked prefill is
            configured and the radix cache is disabled;
          * ``HiRadixCache`` when hierarchical caching is enabled;
          * ``SWARadixCache`` for hybrid models (disaggregation unsupported);
          * plain ``RadixCache`` otherwise.
        Also sets ``decode_mem_cache_buf_multiplier`` and initializes the
        multimodal embedding cache sized by ``SGLANG_VLM_CACHE_SIZE_MB``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            ChunkCacheClass = SWAChunkCache if self.is_hybrid else ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=(
                    self.attn_tp_cpu_group
                    if self.server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
                hicache_io_backend=(
                    "direct"
                    if server_args.attention_backend
                    == "fa3"  # hot fix for incompatibility
                    else server_args.hicache_io_backend
                ),
                hicache_storage_backend=server_args.hicache_storage_backend,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
            )
        elif self.is_hybrid:
            assert (
                self.server_args.disaggregation_mode == "null"
            ), "Hybrid mode does not support disaggregation yet"
            self.tree_cache = SWARadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                sliding_window_size=self.sliding_window_size,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        # Without speculative decoding each decode step needs one slot per
        # request; with it, room for draft tokens plus topk * steps candidates.
        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
        init_embedding_cache(embedding_cache_size * 1024 * 1024)

    def init_profier(self):
        """Reset all profiler-related state to its idle defaults.

        The method name is kept as-is (including the spelling) because
        ``__init__`` and possibly other callers invoke it by this name.
        """
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_start_forward_ct: Optional[int] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
        """Reset throughput/speculative counters and create the metrics collector.

        Args:
            tp_rank: Tensor-parallel rank, used as a metrics label.
            pp_rank: Pipeline-parallel rank, used as a metrics label.
            dp_rank: Data-parallel rank; added as a label only when not ``None``.

        ``metrics_collector`` is only created when ``enable_metrics`` is set.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.total_retracted_reqs = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            labels = {
                "model_name": self.server_args.served_model_name,
                "engine_type": engine_type,
                "tp_rank": tp_rank,
                "pp_rank": pp_rank,
            }
            if dp_rank is not None:
                labels["dp_rank"] = dp_rank
            self.metrics_collector = SchedulerMetricsCollector(labels=labels)

    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV events are enabled.

        Args:
            kv_events_config: Publisher configuration passed through to
                ``EventPublisherFactory.create``; ``None`` when events are off.

        ``kv_event_publisher`` is always defined afterwards (``None`` if disabled).
        """
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )
        else:
            self.kv_event_publisher = None

    def init_disaggregation(self):
        """Create the PD-disaggregation transfer backend and mode-specific queues.

        ``self.transfer_backend`` is always created. Then, per
        ``self.disaggregation_mode``:

        * DECODE: metadata allocator/buffers sized ``2 * req_to_token_pool.size``,
          a transfer queue (requests polling their KV cache) and a
          pre-allocation queue that feeds it.
        * PREFILL: metadata allocator/buffers sized ``2 * max_running_requests``,
          a bootstrap queue and an empty list of requests whose KV cache is
          being sent.
        * any other mode: nothing further is created.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            self._create_disagg_metadata_buffers(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                tp_rank=self.tp_rank,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=self._get_disagg_draft_token_to_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                transfer_backend=self.transfer_backend,
            )

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            self._create_disagg_metadata_buffers(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=self._get_disagg_draft_token_to_kv_pool(),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    def _create_disagg_metadata_buffers(self, buffer_size: int):
        """Create the metadata index allocator and metadata buffers, both sized ``buffer_size``."""
        self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
            buffer_size
        )
        self.disagg_metadata_buffers = MetadataBuffers(
            buffer_size,
            hidden_size=self.model_config.hf_text_config.hidden_size,
            dtype=self.model_config.dtype,
            custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
        )

    def _get_disagg_draft_token_to_kv_pool(self):
        """Return the draft worker's ``token_to_kv_pool``, or ``None`` when there is no draft worker."""
        if self.draft_worker is None:
            return None
        return self.draft_worker.model_runner.token_to_kv_pool
Byron Hsu's avatar
Byron Hsu committed
807

808
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the plain, non-overlapped scheduler loop forever.

        Every iteration ingests newly received requests and picks the next
        batch. A batch is run and its result processed synchronously. With no
        batch, the idle self-checks run instead.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle: self-check and reset the new-token ratio.
                self.check_memory()
                self.check_tree_cache()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()
            else:
                self.process_batch_result(batch, self.run_batch(batch))

            self.last_batch = batch
829

830
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU compute.

        A batch is launched before the result of the previously launched batch
        is post-processed. Launched (batch copy, result) pairs wait in
        ``self.result_queue`` until the next iteration. The previous result is
        processed with the launch_done event of the batch launched in the
        current iteration, or ``None`` when nothing was launched.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Nothing is in flight yet: prime the pipeline with a dummy
                    # first batch, whose processing triggers the
                    # sampling_info_done event.
                    dummy_first_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(
                        dummy_first_batch, None, batch.launch_done
                    )

            if self.last_batch:
                # Post-process the batch launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: the launch_done event passed here is the one of the batch
                # launched in THIS iteration, not the previous batch's.
                self.process_batch_result(
                    prev_batch, prev_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # Idle: self-check and reset the new-token ratio.
                self.check_memory()
                self.check_tree_cache()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

            self.last_batch = batch

876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        ``pp_size`` micro-batch slots are cycled. For slot ``mb_id`` the loop:

        1. receives requests, schedules a batch for the slot and runs it;
        2. on the last PP rank, sends that batch's outputs onward;
        3. receives the outputs of slot ``(mb_id + 1) % pp_size``, which was
           launched earlier, and post-processes that batch;
        4. on non-last PP ranks, forwards the previous outputs, the received
           requests and the hidden-state proxy tensors to the next stage.

        Per-slot state (``bids``, ``can_run_cuda_graphs``) is recorded when a
        slot's batch runs. Its deferred post-processing therefore uses that
        batch's own values, not those of whichever batch ran most recently.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # CUDA-graph flag of the batch last run in each slot. Previously the
        # flag of the *current* slot's result was reported for the *next* slot.
        can_run_cuda_graphs = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    can_run_cuda_graphs[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Rebuild the logits output from any "logits_output.*"
                    # entries that were sent (only present with return_logprob).
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=can_run_cuda_graphs[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.check_tree_cache()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()
1010

1011
1012
    def recv_requests(self) -> List[Req]:
        """Gather newly arrived requests on the leader rank and share them.

        Only attention-TP rank 0 actually reads. On the first pipeline stage it
        drains the tokenizer socket and then the RPC socket without blocking.
        On later stages it receives the list forwarded by the previous stage.
        Other ranks start from ``None``. The list then passes through the input
        blocker, if any, and is broadcast:

        * with DP attention, generate/embedding "work" requests go over the
          attention-TP group and every other ("control") request over the full
          TP group. The returned list is work requests followed by control
          requests.
        * otherwise the whole list is broadcast over the TP group when
          ``tp_size != 1``.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                for source in (self.recv_from_tokenizer, self.recv_from_rpc):
                    # Drain whatever is currently queued on this socket.
                    while True:
                        try:
                            incoming = source.recv_pyobj(zmq.NOBLOCK)
                        except zmq.ZMQError:
                            break
                        recv_reqs.append(incoming)
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                own_rank = self.pp_rank * self.tp_size + dp_offset
                recv_reqs = point_to_point_pyobj(
                    [],
                    own_rank,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    own_rank,
                )

        if self.input_blocker is not None:
            recv_reqs = self.input_blocker.handle(recv_reqs)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(
                    recv_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return recv_reqs

        work_reqs = None
        control_reqs = None
        if self.attn_tp_rank == 0:
            work_reqs = []
            control_reqs = []
            for item in recv_reqs:
                if isinstance(
                    item, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    work_reqs.append(item)
                else:
                    control_reqs.append(item)

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
1092
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and forward any output it produces.

        Health-check generation requests are only counted, not dispatched,
        while a chunked request is pending or the running batch is non-empty.
        Work requests are rejected with a SERVICE_UNAVAILABLE abort once the
        waiting queue has reached ``self.max_queued_requests``. Dispatcher
        outputs go back over the RPC socket when they are ``RpcReqOutput``
        (and that socket exists); all others go to the tokenizer.
        """
        for recv_req in recv_reqs:
            if is_health_check_generate_req(recv_req):
                server_busy = (
                    self.chunked_req is not None or not self.running_batch.is_empty()
                )
                if server_busy:
                    self.return_health_check_ct += 1
                    continue

            if (
                is_work_request(recv_req)
                and len(self.waiting_queue) >= self.max_queued_requests
            ):
                self.send_to_tokenizer.send_pyobj(
                    AbortReq(
                        recv_req.rid,
                        finished_reason={
                            "type": "abort",
                            "status_code": HTTPStatus.SERVICE_UNAVAILABLE,
                            "message": "The request queue is full.",
                        },
                    )
                )
                continue

            output = self._request_dispatcher(recv_req)
            if output is None:
                continue
            if not isinstance(output, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(output)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(output)
1121
1122
1123
1124
1125

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        Requests that fail validation (unknown session id, multimodal prompt
        too long, over-long input, bad ``logprob_start_len``) are marked with
        ``set_finish_with_abort`` and still enqueued, so the abort is reported
        through the normal output path. In disaggregated mode, a request
        without a bootstrap room is aborted and streamed back immediately
        without being enqueued.

        A request with a constrained-decoding spec (JSON schema, regex, EBNF
        or structural tag) whose grammar is not cached yet goes to
        ``self.grammar_queue``. Everything else goes through
        ``_add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
                vocab_size=self.model_config.vocab_size,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not in self.sessions.
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens so prompt + output stays below max_req_len
        # (None means "unbounded", represented here as 1 << 30).
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # By the guard above, structural_tag is not None here. A plain
                # ``else`` keeps ``key`` bound even for a falsy tag (e.g. ""),
                # which previously raised UnboundLocalError.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            else:
                if value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                    error_msg = f"Invalid grammar request with cache hit: {key=}"
                    req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and route ``req`` to this mode's queue.

        PREFILL mode uses the prefill bootstrap queue and DECODE mode the
        decode pre-allocation queue. Otherwise the request is appended to
        ``self.waiting_queue``. When ``self.enable_hicache_storage`` is set, a
        storage prefetch of the unmatched part of ``req.fill_ids`` is started
        first.
        """
        req.queue_time_start = time.perf_counter()

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            return
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        if self.enable_hicache_storage:
            req.init_next_round_input(self.tree_cache)
            last_hash = req.last_host_node.get_last_hash_value()
            matched_len = len(req.prefix_indices) + req.host_hit_length
            # Prefetch when nothing matched, or when something matched and the
            # last host node provides a hash value.
            if matched_len == 0 or (matched_len > 0 and last_hash is not None):
                self.tree_cache.prefetch_from_storage(
                    req.rid,
                    req.last_host_node,
                    req.fill_ids[matched_len:],
                    last_hash,
                )

        self.waiting_queue.append(req)

1301
    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Enqueue several requests at once, routed by disaggregation mode.

        Args:
            reqs: Requests to enqueue, in order.
            is_retracted: Forwarded only to the decode pre-allocation queue;
                ignored in the other modes.
        """
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # Decode servers stage requests in the pending pre-allocation queue.
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return
        self.waiting_queue.extend(reqs)
1311
1312
1313

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs, if present, are expanded into padded input ids first.
        Every request is enqueued exactly once:

        * If the expanded multimodal prompt reaches ``max_req_input_len``, the
          request is marked aborted via ``set_finish_with_abort`` and then enqueued.
        * If ``validate_input_length`` reports an error, the request is enqueued
          as-is. NOTE(review): presumably that helper marks the request finished.
          Confirm in its definition.
        * Otherwise ``logprob_start_len`` is set to the last input position
          before enqueueing.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        rejected = False
        raw_mm_inputs = recv_req.image_inputs
        if raw_mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(raw_mm_inputs)
            # Each image placeholder becomes a run of dummy tokens that will later
            # receive the image embeddings.
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            rejected = len(req.origin_input_ids) >= self.max_req_input_len
            if rejected:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )

        if not rejected:
            length_error = validate_input_length(
                req,
                self.max_req_input_len,
                self.server_args.allow_auto_truncate,
            )
            if not length_error:
                req.logprob_start_len = len(req.origin_input_ids) - 1

        self._add_request_to_queue(req)
1357

1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
    def _emit_kv_metrics(self):
        """Send a ``KvMetrics`` snapshot built from ``self.stats``.

        Token counts are scaled by ``max_total_num_tokens``. A missing
        ``dp_rank`` is reported as 0. If ``send_metrics_from_scheduler`` is
        already closed, the snapshot is dropped without error.
        """
        stats = self.stats
        capacity = self.max_total_num_tokens

        snapshot = KvMetrics()
        snapshot.request_active_slots = stats.num_running_reqs
        snapshot.request_total_slots = self.max_running_requests
        snapshot.kv_active_blocks = int(stats.token_usage * capacity)
        snapshot.kv_total_blocks = capacity
        snapshot.num_requests_waiting = stats.num_queue_reqs
        snapshot.gpu_cache_usage_perc = stats.token_usage
        snapshot.gpu_prefix_cache_hit_rate = stats.cache_hit_rate
        snapshot.data_parallel_rank = 0 if self.dp_rank is None else self.dp_rank

        if self.send_metrics_from_scheduler.closed:
            return
        self.send_metrics_from_scheduler.send_pyobj(snapshot)

1374
1375
1376
1377
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a newly formed prefill batch.

        Updates ``last_input_throughput`` from the previous batch: the previous
        batch's token count divided by the time since the previous call.
        Records this batch's input-token count for next time.

        When metrics are enabled, it also pushes to the metrics collector:
        running/used/queued counts, token usage, cache-hit rate, and the mean
        queue latency of ``can_run_list``. Pending KV events are published in
        all cases.

        ``can_run_list`` must be non-empty, because the mean queue latency
        divides by its length.
        """
        elapsed = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.last_prefill_tokens / elapsed
        self.last_prefill_tokens = adder.log_input_tokens

        if not self.is_hybrid:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"token usage: {token_usage:.2f}, "
        else:
            full_used, swa_used, full_usage, swa_usage, *_ = (
                self._get_swa_token_info()
            )
            # Report the more loaded of the two pools as the overall figure.
            num_used = max(full_used, swa_used)
            token_usage = max(full_usage, swa_usage)
            token_msg = (
                f"full token usage: {full_usage:.2f}, "
                f"swa token usage: {swa_usage:.2f}, "
            )

        num_new_seq = len(can_run_list)
        queue_len = len(self.waiting_queue)

        parts = [
            "Prefill batch. ",
            f"#new-seq: {num_new_seq}, ",
            f"#new-token: {adder.log_input_tokens}, ",
            f"#cached-token: {adder.log_hit_tokens}, ",
            token_msg,
        ]
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            parts += [
                f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, ",
                f"#queue-req: {queue_len}, ",
                f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, ",
                f"input throughput (token/s): {self.last_input_throughput:.2f}, ",
            ]
        else:
            parts += [
                f"#running-req: {running_bs}, ",
                f"#queue-req: {queue_len}, ",
            ]
        logger.info("".join(parts))

        if self.enable_metrics:
            hit_tokens = adder.log_hit_tokens
            seen_tokens = adder.log_input_tokens + hit_tokens

            stats = self.stats
            stats.num_running_reqs = running_bs
            stats.num_used_tokens = num_used
            stats.token_usage = round(token_usage, 2)
            stats.num_queue_reqs = queue_len
            stats.cache_hit_rate = (
                hit_tokens / seen_tokens if seen_tokens > 0 else 0.0
            )

            waited = sum(
                r.queue_time_end - r.queue_time_start for r in can_run_list
            )
            stats.avg_request_queue_latency = waited / num_new_seq

            self.metrics_collector.log_stats(stats)
            self._emit_kv_metrics()
        self._publish_kv_events()
1446

1447
1448
1449
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log decode throughput and occupancy, and refresh scheduler metrics.

        Args:
            can_run_cuda_graph: Whether the last decode step could use a CUDA
                graph. It is echoed in the log line.
            running_batch: Batch whose request count is reported. Falls back to
                ``self.running_batch`` when not given (or falsy).

        Side effects:
            * Resets ``num_generated_tokens``.
            * Under speculative decoding, folds the per-interval accepted-token
              and forward counters into the cumulative totals, then resets them.
            * When metrics are enabled, updates ``self.stats`` and emits KV metrics.
            * Always publishes pending KV events.
        """
        batch = running_batch or self.running_batch

        # Use one timestamp so the measured gap and the start of the next interval
        # agree exactly (two perf_counter() calls would leave a small untracked gap).
        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)

        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                full_token_usage,
                swa_token_usage,
                _,
                _,
                _,
                _,
            ) = self._get_swa_token_info()
            # Report the more loaded of the full / SWA pools as the overall usage.
            num_used = max(full_num_used, swa_num_used)
            token_usage = max(full_token_usage, swa_token_usage)
            token_msg = (
                f"#full token: {full_num_used}, "
                f"full token usage: {full_token_usage:.2f}, "
                f"#swa token: {swa_num_used}, "
                f"swa token usage: {swa_token_usage:.2f}, "
            )
        else:
            num_used, token_usage, _, _ = self._get_token_info()
            token_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard the interval in which no speculative forward was counted;
            # dividing by the zero counter would raise ZeroDivisionError.
            if self.spec_num_total_forward_ct > 0:
                spec_accept_length = (
                    self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                )
            else:
                spec_accept_length = 0.0
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            # Decode steps reuse cached prefixes by construction; no new hits to report.
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.stats.total_retracted_reqs = self.total_retracted_reqs
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()
1522

Lianmin Zheng's avatar
Lianmin Zheng committed
1523
    def check_memory(self):
        """Assert that the KV-token pool and the request-slot pool are fully free.

        Token pool check:
            * Hybrid (SWA) models: both the full and SWA pools must report
              zero used tokens.
            * Other models: ``available + evictable`` must equal
              ``max_total_num_tokens``. With hierarchical cache enabled, the
              tree cache's protected size is excluded from that total.

        Request-slot check:
            * ``req_to_token_pool.free_slots`` must cover the whole pool,
              including ``pre_alloc_size`` in DECODE disaggregation mode.

        It also refreshes idle-time metrics at most every 30 seconds when
        scheduler metrics are enabled, and always publishes pending KV events.

        Raises:
            ValueError: if either pool reports a leak.
        """
        if self.is_hybrid:
            (
                full_num_used,
                swa_num_used,
                _,
                _,
                full_available_size,
                full_evictable_size,
                swa_available_size,
                swa_evictable_size,
            ) = self._get_swa_token_info()
            memory_leak = full_num_used != 0 or swa_num_used != 0
            token_msg = (
                f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
                f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
            )
        else:
            _, _, available_size, evictable_size = self._get_token_info()
            protected_size = self.tree_cache.protected_size()
            # With hierarchical cache enabled, protected tokens are excluded from
            # the total that must be free/evictable.
            expected_free = (
                self.max_total_num_tokens
                if not self.enable_hierarchical_cache
                else self.max_total_num_tokens - protected_size
            )
            memory_leak = (available_size + evictable_size) != expected_free
            token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

        if memory_leak:
            msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}"
            raise ValueError(msg)

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        num_free_slots = len(self.req_to_token_pool.free_slots)
        if num_free_slots != req_total_size:
            # Report the same total the check compared against (in DECODE mode it
            # includes pre_alloc_size), separated from the headline by a space.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.current_scheduler_metrics_enabled()
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            if self.is_hybrid:
                (
                    full_num_used,
                    swa_num_used,
                    full_token_usage,
                    swa_token_usage,
                    _,
                    _,
                    _,
                    _,
                ) = self._get_swa_token_info()
                num_used = max(full_num_used, swa_num_used)
                token_usage = max(full_token_usage, swa_token_usage)
            else:
                num_used, token_usage, _, _ = self._get_token_info()
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(token_usage, 2)
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1599

Hanming Lu's avatar
Hanming Lu committed
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
    def check_tree_cache(self):
        """Run the radix tree's ``sanity_check`` for hybrid (SWA) schedulers.

        This is a no-op unless ``is_hybrid`` is set and the tree cache is an
        ``SWARadixCache``.
        """
        if not self.is_hybrid:
            return
        if isinstance(self.tree_cache, SWARadixCache):
            self.tree_cache.sanity_check()

    def _get_token_info(self):
        available_size = self.token_to_kv_pool_allocator.available_size()
        evictable_size = self.tree_cache.evictable_size()
        num_used = self.max_total_num_tokens - (available_size + evictable_size)
        token_usage = num_used / self.max_total_num_tokens
        return num_used, token_usage, available_size, evictable_size

    def _get_swa_token_info(self):
        full_available_size = self.token_to_kv_pool_allocator.full_available_size()
        full_evictable_size = self.tree_cache.full_evictable_size()
        swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
        swa_evictable_size = self.tree_cache.swa_evictable_size()
        full_num_used = self.full_tokens_per_layer - (
            full_available_size + full_evictable_size
        )
        swa_num_used = self.swa_tokens_per_layer - (
            swa_available_size + swa_evictable_size
        )
        full_token_usage = full_num_used / self.full_tokens_per_layer
        swa_token_usage = swa_num_used / self.swa_tokens_per_layer
        return (
            full_num_used,
            swa_num_used,
            full_token_usage,
            swa_token_usage,
            full_available_size,
            full_evictable_size,
            swa_available_size,
            swa_evictable_size,
        )

1635
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        1. Merge the previous extend (prefill) batch into ``self.running_batch``,
           excluding any request that is still being chunked. First, the current
           chunked request is cached in the tree and its ``req_pool_idx`` freed.
        2. Try to form a new prefill batch. Prefill takes priority over decode.
        3. Otherwise update the running decode batch and run it if non-empty.
        4. When MLP sync is required, pass the choice through
           ``prepare_mlp_sync_batch``. Under speculative decoding this happens
           early, on the prefill candidate.

        Returns the chosen batch. ``None`` means nothing to run, unless MLP-sync
        preparation substitutes its own result.
        """
        excluded_chunked = set()
        if self.chunked_req:
            # Take the in-flight chunked request out, so only finished requests
            # from the last batch are merged into running_batch.
            excluded_chunked.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the microbatch
                # can still track an outdated chunked_req; discard it too.
                excluded_chunked.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded_chunked))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        needs_mlp_sync = require_mlp_sync(self.server_args)
        if needs_mlp_sync and not self.spec_algorithm.is_none():
            # Under speculative decoding, prefill and decode batches cannot share a
            # DP attention group. Prepare idle batches up front, so decode
            # preparation is skipped whenever the group has a prefill batch.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            needs_mlp_sync = new_batch is None

        if new_batch is not None:
            chosen = new_batch
        elif self.running_batch.is_empty():
            chosen = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = None if self.running_batch.is_empty() else self.running_batch

        if needs_mlp_sync:
            chosen = self.prepare_mlp_sync_batch(chosen)
        return chosen
1693

1694
1695
1696
1697
1698
1699
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests fit next to ``running_bs`` running ones.

        The budget is ``max_micro_batch_size - running_bs``. With pipeline
        parallelism (``pp_size > 1``), it is also capped by the free slots in
        ``req_to_token_pool``. The value may be zero or negative.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1700
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Form the next prefill (extend) batch from the waiting queue.

        Returns ``None`` when no prefill is possible this step:
            * the running batch is full or the waiting queue is empty, and there
              is no chunked request to continue;
            * no more requests can be allocated; or
            * no waiting request could be admitted.

        Otherwise:
            * the admitted requests are removed from ``waiting_queue``;
            * ``self.chunked_req`` is updated;
            * a ``ScheduleBatch`` already prepared for extend is returned.

        In mixed-chunk mode, running decode requests may be folded into the new
        batch. In that case ``self.running_batch`` is replaced by an empty batch
        that keeps its ``batch_is_full`` flag.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.enable_lora:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.enable_lora
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc and transfer queues can also take
                # memory, so also cap admission by req_to_token_pool's actual
                # available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            if self.enable_hicache_storage:
                self.tree_cache.check_prefetch_progress(req.rid)

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Rebuilding it inside the comprehension
        # made this filter O(len(waiting_queue) * len(can_run_list)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.current_scheduler_metrics_enabled():
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1853
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step (mutates ``batch``).

        1. Drops finished requests. An emptied batch is returned immediately,
           with ``batch_is_full`` cleared.
        2. Checks decode memory:
           * If it is insufficient, or the ``TEST_RETRACT`` debug switch fires for
             batches larger than 10, some requests are retracted back to the queue
             and ``new_token_ratio`` adopts the value from ``retract_decode``.
           * Otherwise ``new_token_ratio`` decays, but never below
             ``min_new_token_ratio``.
        3. Clears ``batch_is_full`` if the batch shrank, then prepares the
           decode tensors.

        Returns ``batch`` itself.
        """
        size_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            ratio_before = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            n_retracted = len(retracted)

            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {n_retracted}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )

            self._extend_requests_to_queue(retracted, is_retracted=True)
            self.total_retracted_reqs += n_retracted

        if batch.batch_size() < size_at_entry:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1892

1893
1894
1895
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Executes one forward pass for ``batch``. Generation models go through
        the TP worker (no speculative decoding) or the draft worker
        (speculative decoding); embedding / reward models go through
        ``forward_batch_embedding``.

        Args:
            batch: The scheduled batch to execute.

        Returns:
            A ``GenerationBatchResult`` when ``self.is_generation`` is true,
            otherwise an ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        # Optional artificial delay before every forward pass.
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()

                # update the consumer index of hicache to the running batch
                self.tp_worker.set_hicache_consumer(
                    model_worker_batch.hicache_consumer_index
                )
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    # Non-last pipeline ranks get proxy tensors for the next stage
                    # in place of logits; the token-ids slot is discarded.
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                # NOTE(review): this branch never binds
                # ``pp_hidden_states_proxy_tensors``, so building the result below
                # on a non-last pipeline rank would raise NameError; presumably
                # speculative decoding is only used without PP — confirm.
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                bs = batch.batch_size()
                # The accepted-token total adds one token per request on top of
                # ``num_accepted_tokens``; the forward counter grows by batch size.
                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
                self.spec_num_total_forward_ct += bs
                self.num_generated_tokens += num_accepted_tokens

            # Only the last pipeline rank holds sampled token ids.
            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            else:
                extend_input_len_per_req = None
            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_logprob_start_len_per_req = None

            # Logits / token ids are only populated on the last PP rank; other
            # ranks carry the hidden-state proxies instead.
            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
Chayenne's avatar
Chayenne committed
1973

1974
1975
1976
1977
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Post-process the output of ``run_batch`` according to the batch's forward mode.

        Decode and extend batches are handed to their dedicated processors.
        Idle batches (only under the overlap scheduler) and dummy-first batches
        just resolve / release the next batch's sampling info. Afterwards, one
        pending health-check signal (if any) is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return

        # Answer a pending health check from here so the probe is not starved
        # behind a long-context prefill. Caveat: this path does not check the
        # status of the detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1998
1999
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Sync batch metadata across DP ranks using this scheduler's configuration.

        Adapter around ``prepare_mlp_sync_batch_raw`` that sources every setting
        from ``self`` and ``self.server_args``.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_group=self.tp_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
            require_mlp_tp_gather=require_mlp_tp_gather(args),
            disable_overlap_schedule=args.disable_overlap_schedule,
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
        require_mlp_tp_gather: bool,
        disable_overlap_schedule: bool,
    ):
        """Exchange per-rank batch metadata over ``tp_group`` and attach it to the batch.

        Each rank packs an int64 vector (token count, CUDA-graph eligibility,
        logprob token count, extend flag, followed by the values returned by
        ``TboDPAttentionPreparer.prepare_all_gather``) and all-gathers it into a
        ``(dp_size, attn_tp_size, 6)`` tensor. The four scalar fields are read
        back from attention-TP rank 0 of each DP group only. If this rank has no
        batch but some DP group reports tokens, an idle batch is created with
        ``get_idle_batch`` and used instead.

        This issues ``all_gather_into_tensor`` on ``tp_group``, so every rank of
        that group has to call it.

        Args:
            local_batch: This rank's batch, or None if it has nothing to run.
            dp_size: First dimension of the gather buffer (DP groups).
            attn_tp_size: Second dimension of the gather buffer.
            tp_group: Group providing ``device_group``, ``cpu_group`` and ``device``.
            get_idle_batch: Zero-argument factory for an idle batch.
            disable_cuda_graph: If True, ``can_run_dp_cuda_graph`` is left unset.
            spec_algorithm: Not read in this body.
            speculative_num_draft_tokens: Not read in this body.
            enable_two_batch_overlap: Forwarded to ``prepare_all_gather``.
            enable_deepep_moe: Forwarded to ``prepare_all_gather``.
            deepep_mode: Forwarded to ``prepare_all_gather``.
            require_mlp_tp_gather: If False the batch records only its own
                counts; if True it records the counts of every DP group.
            disable_overlap_schedule: Use the device group (True) or the CPU
                group (False) for the all-gather.

        Returns:
            The (possibly newly created idle) batch with global metadata
            attached, or None if there was no local batch and no DP group had
            tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            # Decode batches count one token per request.
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        # Only absent / decode / idle batches are CUDA-graph eligible locally;
        # the global decision below is the minimum over DP groups.
        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()
        if disable_overlap_schedule:
            group = tp_group.device_group
            device = tp_group.device
        else:
            # NOTE(review): presumably the CPU group is chosen under the overlap
            # scheduler to avoid synchronizing with the GPU stream — confirm.
            group = tp_group.cpu_group
            device = "cpu"

        # Per-rank layout: [num_tokens, can_cuda_graph, num_tokens_for_logprob,
        # is_extend_in_batch, <prepare_all_gather values...>].
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
            device=device,
        )
        # The trailing 6 must equal len(local_info); presumably that is the 4
        # scalars above plus 2 values from prepare_all_gather (columns 4:6 are
        # handed to compute_output below) — keep these in sync.
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
            device=device,
        )
        # Gathers into a flattened view of the freshly allocated (contiguous)
        # buffer, so the results land in ``global_info``.
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=group,
        )
        # Scalar fields come from attention-TP rank 0 of each DP group.
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Some DP group has work while this rank has none: run an idle batch.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
2124
2125
2126
2127
2128

    def get_idle_batch(self):
        """Create an empty ``ScheduleBatch`` prepared for an idle forward pass.

        Used, for example, by ``prepare_mlp_sync_batch_raw`` when this rank has
        no batch but another DP group reports tokens.
        """
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

2139
2140
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order and the scan stops at the first request
        whose grammar future is not ready within ``poll_timeout`` seconds. That
        request counts as timed out once it has been polled more than
        ``GRAMMAR_TIMEOUT / poll_timeout`` times. Requests already finished
        (aborted by ``AbortReq``) count as ready.

        With TP > 1, the ranks agree on the ready / timed-out counts via a MAX
        all-reduce. Each rank then resolves (blocking) the extra requests that
        other ranks saw as ready. It aborts the timed-out requests that follow
        the agreed ready prefix. As a result, every rank moves exactly the same
        prefix of the queue.
        """
        poll_timeout = 0.03  # seconds to block on each pending grammar future

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                self._install_resolved_grammar(
                    req, req.grammar.result(timeout=poll_timeout)
                )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Requests another rank already saw as ready: wait for them here too.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                self._install_resolved_grammar(req, req.grammar.result())
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Abort the timed-out request(s) right after the AGREED ready prefix.
        # Starting at the rank-local ``num_ready_reqs`` would cancel requests
        # that were just resolved above. It would also let the request at
        # ``num_ready_reqs_max`` reach the waiting queue with an unresolved
        # future as its grammar. That request has not been resolved on this
        # rank.
        timeout_end = num_ready_reqs_max + num_timeout_reqs_max
        for i in range(num_ready_reqs_max, timeout_end):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = timeout_end

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

    def _install_resolved_grammar(self, req, grammar):
        """Attach a resolved grammar to ``req``, cache a copy, and abort ``req`` if it is invalid."""
        req.grammar = grammar
        self.grammar_backend.set_cache(req.grammar_key, grammar.copy())
        if grammar is INVALID_GRAMMAR_OBJ:
            req.set_finish_with_abort(
                f"Invalid grammar request: {req.grammar_key=}"
            )

2204
2205
2206
2207
2208
2209
2210
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Mark the next batch's sampling info as ready.

        If the next batch carries grammars, its regex vocab mask is refreshed
        and the current stream is synchronized before the done-event is set.
        Does nothing when there is no next-batch sampling info.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

2211
2212
2213
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. If a batch is in flight
        (``self.cur_batch`` is set) and ``self.forward_ct`` has not advanced for
        more than ``watchdog_timeout`` seconds, the loop exits. It then logs
        batch and memory-pool diagnostics (unless request logging is disabled)
        and dumps the scheduler stacks. Finally it sends ``SIGQUIT`` to the
        parent process. This method never returns normally.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()
        # True division: ``// 2`` yields a zero-length sleep (a busy loop) for
        # timeouts below 2 seconds.
        poll_interval = self.watchdog_timeout / 2

        stuck_batch = None
        while True:
            current = time.perf_counter()
            # Snapshot the batch once per iteration: the scheduler thread may
            # reset ``self.cur_batch`` to None concurrently, and the diagnostics
            # below must use the batch that was actually observed as stuck.
            stuck_batch = self.cur_batch
            if stuck_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(poll_interval)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if self.is_hybrid:
                (
                    _,
                    _,
                    _,
                    _,
                    full_available_size,
                    full_evictable_size,
                    swa_available_size,
                    swa_evictable_size,
                ) = self._get_swa_token_info()
                info_msg = (
                    f"{full_available_size=}, "
                    f"{full_evictable_size=}, "
                    f"{swa_available_size=}, "
                    f"{swa_evictable_size=}, "
                )
            else:
                _, _, available_size, evictable_size = self._get_token_info()
                info_msg = f"{available_size=}, " f"{evictable_size=}, "
            # Same message format as before, but read from the snapshot.
            logger.error(
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                f"{info_msg}"
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2264
2265
2266
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether the flush happened."""
        return FlushCacheReqOutput(success=self.flush_cache())
2267

2268
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when nothing is queued or running. With
        ``pp_size > 1``, that includes every micro-batch in ``running_mbs``.
        Otherwise the flush is skipped and a warning is logged.

        On success this resets the tree cache, the grammar backend (if any), the
        request/KV pools (and the draft worker's, when speculative decoding is
        enabled) and the generation/speculative counters, then empties the CUDA
        caching allocator.

        Returns:
            True if the caches were flushed, False if pending requests prevented it.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Log through the module logger (as the rest of this file does),
            # not the root ``logging`` module.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
2305
2306
    def get_load(self):
        """Estimate the scheduler load in tokens.

        The load is the number of KV-cache tokens that are neither free nor
        evictable. For hybrid models it is the larger of the full-pool and
        SWA-pool usages. Added to that are the prompt lengths of every waiting
        request, including the disaggregation bootstrap (prefill) or prealloc
        (decode) queue.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        allocator = self.token_to_kv_pool_allocator
        cache = self.tree_cache
        if not self.is_hybrid:
            load = (
                self.max_total_num_tokens
                - allocator.available_size()
                - cache.evictable_size()
            )
        else:
            full_in_use = (
                self.full_tokens_per_layer
                - allocator.full_available_size()
                - cache.full_evictable_size()
            )
            swa_in_use = (
                self.swa_tokens_per_layer
                - allocator.swa_available_size()
                - cache.swa_evictable_size()
            )
            load = max(full_in_use, swa_in_use)

        load += sum(len(r.origin_input_ids) for r in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            bootstrap = self.disagg_prefill_bootstrap_queue.queue
            load += sum(len(r.origin_input_ids) for r in bootstrap)
        elif mode == DisaggregationMode.DECODE:
            prealloc = self.disagg_decode_prealloc_queue.queue
            load += sum(len(d.req.origin_input_ids) for d in prealloc)

        return load

2339
2340
2341
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Report the global server args plus runtime statistics.

        Includes the last generation throughput, memory usage (weights, KV
        cache, token capacity and, off-CPU, CUDA graphs), and the average
        speculative acceptance length when available. The step-time dict is
        added when ``RECORD_STEP_TIME`` is set, and the current load is always
        included.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        model_runner = self.tp_worker.worker.model_runner
        memory = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(
                self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
            ),
            "token_capacity": int(self.max_total_num_tokens),
        }
        state["memory_usage"] = memory
        if not _is_cpu:
            memory["cuda_graph"] = round(model_runner.cuda_graph_mem_usage, 2)

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()

        return GetInternalStateReqOutput(internal_state=state)
2367
2368
2369
2370
2371

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Only keys in ``args_allow_update`` may change.
        ``max_micro_batch_size`` must lie in
        ``[1, max_running_requests // pp_size]``. Validation stops at the first
        rejected key, and in that case nothing is applied. When the update is
        applied, the running speculative acceptance stats are logged and reset.

        Returns:
            ``SetInternalStateReqOutput`` whose ``updated`` flag is True only if
            the new values were applied, together with the current global
            server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                success = False
                break

        if success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            global_server_args_dict.update(server_args_dict)
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        return SetInternalStateReqOutput(
            # Report whether the update was actually applied.
            updated=success,
            server_args=global_server_args_dict,
        )

2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler method named ``recv_req.method`` with ``recv_req.parameters``.

        Any exception raised by the lookup or the call is logged with its
        traceback and reported in the output instead of propagating. A
        ``barrier()`` follows the call in every case.

        NOTE(review): ``recv_req.method`` can name any attribute of the
        scheduler, so this entry point must only be reachable by trusted callers.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        # ``error`` rather than ``exec``: avoid shadowing the builtin.
        error = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            error = e
            logger.exception(f"Failed to call rpc {recv_req.method}: {e}")

        barrier()
        return RpcReqOutput(error is None, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Ask the model runner to save the model to ``params["url"]``."""
        url = params["url"]
        self.tp_worker.worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Ask the model runner to save sharded weights per ``params`` (path, pattern, max_size)."""
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_sharded_model(
            path=params["path"], pattern=params["pattern"], max_size=params["max_size"]
        )

2441
2442
    def abort_request(self, recv_req: AbortReq):
        """Abort requests whose rid starts with ``recv_req.rid`` (every request if ``abort_all``).

        The abort mechanism depends on where a request currently lives:

        * waiting queue: popped immediately and an ``AbortReq`` is sent back to
          the tokenizer;
        * grammar queue: grammar future cancelled and request marked finished;
        * disaggregation queues: the KV sender / receiver is told to abort
          when it supports it;
        * running batch: flagged ``to_abort`` so the next decode step cleans up.
        """

        def matches(rid) -> bool:
            return recv_req.abort_all or rid.startswith(recv_req.rid)

        # Delete requests in the waiting queue
        to_del = [i for i, req in enumerate(self.waiting_queue) if matches(req.rid)]

        # Sort in reverse order to avoid index issues when deleting
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if matches(req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                if req.grammar:
                    req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests not in the waiting queue when PD disaggregation is enabled.
        # Log only the requests that actually match (previously every queued
        # request was logged as aborted, whether or not it matched).
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Abort requests that have not yet been bootstrapped
            for req in self.disagg_prefill_bootstrap_queue.queue:
                if matches(req.rid):
                    logger.debug(f"Abort bootstrap queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

            # Abort in-flight requests
            for req in self.disagg_prefill_inflight_queue:
                if matches(req.rid):
                    logger.debug(f"Abort inflight queue request. {req.rid=}")
                    if hasattr(req.disagg_kv_sender, "abort"):
                        req.disagg_kv_sender.abort()

        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Abort requests that have not yet finished preallocation
            for decode_req in self.disagg_decode_prealloc_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

            # Abort requests waiting for kvcache to release tree cache
            for decode_req in self.disagg_decode_transfer_queue.queue:
                if matches(decode_req.req.rid):
                    logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
                    if hasattr(decode_req.kv_receiver, "abort"):
                        decode_req.kv_receiver.abort()

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.finished() and matches(req.rid):
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2514

2515
2516
2517
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Not implemented in this class; always raises ``NotImplementedError``."""
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2518
2519
2520
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the caches are flushed. The flush is asserted to succeed,
        which fails if requests are still pending. On failure the worker's
        message is logged.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2527

2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
2542
2543
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        Delegates to the TP worker and returns its result unchanged.
        """
        return self.tp_worker.load_lora_adapter(recv_req)

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter.

        Delegates to the TP worker and returns its result unchanged.
        """
        return self.tp_worker.unload_lora_adapter(recv_req)

2544
2545
2546
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group via the TP worker."""
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
2548
2549

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates the update to the TP worker. On success, and when
        ``recv_req.flush_cache`` is set, the caches are flushed. That flush is
        asserted to succeed, which fails if requests are still pending. On
        failure the worker's message is logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput`` carrying the worker's
            success flag and message.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2562

2563
2564
2565
2566
2567
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from tensors via the TP worker.

        On success, the cache is flushed (with the result asserted) only when
        ``recv_req.flush_cache`` is truthy. On failure, the worker's message
        is logged. In both cases the TP CPU group then meets at a barrier.

        Returns:
            ``UpdateWeightsFromTensorReqOutput`` with the worker's
            (success, message) pair.
        """
        ok, msg = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not ok:
            logger.error(msg)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        # Unconditional: every rank reaches this barrier whatever the outcome.
        barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(ok, msg)
2575

2576
2577
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a named parameter from the TP worker and wrap it for the reply."""
        return GetWeightsByNameReqOutput(
            self.tp_worker.get_weights_by_name(recv_req)
        )
2579

2580
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release GPU memory for the requested memory types.

        ``recv_req.tags`` selects what to release. ``None`` or an empty list
        means both weights and KV cache.

        * KV cache: paused through ``memory_saver_adapter``, then the cache
          is flushed.
        * Weights: the model's named buffers are cloned into
          ``self.stashed_model_static_state`` (which
          ``resume_memory_occupation`` restores). The TP CPU group then meets
          at a barrier, and the weight memory is paused.

        Returns:
            ``ReleaseMemoryOccupationReqOutput()`` (constructed with no
            arguments).
        """
        tags = recv_req.tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            # Snapshot buffers before the weight memory is paused so they can
            # be written back on resume.
            model = self.tp_worker.worker.model_runner.model
            self.stashed_model_static_state = _export_static_state(model)
            torch.distributed.barrier(self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()
2598

2599
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Re-acquire GPU memory for the requested memory types.

        ``recv_req.tags`` selects what to resume. ``None`` or an empty list
        means both weights and KV cache.

        * Weights: memory is resumed and the TP CPU group meets at a barrier.
          The buffers stashed in ``self.stashed_model_static_state`` are then
          copied back into the model, and the stash is deleted.
        * KV cache: memory is resumed.

        Returns:
            ``ResumeMemoryOccupationReqOutput()`` (constructed with no
            arguments).
        """
        tags = recv_req.tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
            # NOTE(review): the stash is set by release_memory_occupation; if
            # weights were never released this raises AttributeError.
            model = self.tp_worker.worker.model_runner.model
            _import_static_state(model, self.stashed_model_static_state)
            del self.stashed_model_static_state

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()

2619
2620
2621
2622
2623
2624
2625
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` on the scheduler.

        A non-positive value is normalized to ``None`` (no slowdown). ``None``
        and positive values are stored as given.
        """
        requested = recv_req.forward_sleep_time
        disabled = requested is not None and requested <= 0
        self.forward_sleep_time = None if disabled else requested
        return SlowDownReqOutput()

2626
    def profile(self, recv_req: ProfileReq):
        """Dispatch a start/stop profiling request.

        START_PROFILE:
          * If ``profile_by_stage`` or ``start_step`` is set, the profiler is
            only configured. ``_profile_batch_predicate`` later starts and
            stops it as batches run.
          * Otherwise the profiler is configured and started immediately. If
            configuration fails (e.g. a profile is already in progress), that
            failure is returned and no second profiler is started.

        Any other request type stops the current profile.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_result = self.init_profile(
            recv_req.output_dir,
            recv_req.start_step,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
            recv_req.profile_id,
        )
        if recv_req.profile_by_stage or recv_req.start_step:
            # Deferred start: the per-batch predicate drives start/stop.
            return init_result
        if not init_result.success:
            # Starting anyway would replace a running profiler and report a
            # false success.
            return init_result
        # No stage argument: start_profile uses ``stage`` only to label its
        # log line, and this is not a per-stage profile.
        return self.start_profile()

2654
    def init_profile(
        self,
        output_dir: Optional[str],
        start_step: Optional[int],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Record the profiler configuration without starting any profiler.

        Args:
            output_dir: trace directory. Defaults to
                ``$SGLANG_TORCH_PROFILER_DIR`` or ``/tmp``.
            start_step: if truthy, the forward count at which to start. It is
                clamped to no earlier than the next forward pass.
            num_steps: if truthy, how many steps (or, in by-stage mode, how
                many prefill and decode batches) to profile.
            activities: activity names. Defaults to ``["CPU", "GPU"]``.
            with_stack: forwarded to the torch profiler when it is started.
            record_shapes: forwarded to the torch profiler when it is started.
            profile_by_stage: profile prefill and decode as separate stages.
            profile_id: identifier used in trace file names.

        Returns:
            A failed ``ProfileReqOutput`` (and no state change) if a profile
            is already in progress; otherwise a successful one.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage
        self.torch_profiler_output_dir = (
            output_dir
            if output_dir is not None
            else os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        )
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = (
            activities if activities is not None else ["CPU", "GPU"]
        )
        self.profile_id = profile_id

        if start_step:
            # Never schedule the start in the past.
            self.profiler_start_forward_ct = max(start_step, self.forward_ct + 1)

        if not num_steps:
            # Open-ended profile: no automatic stop.
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")

        self.profile_steps = num_steps
        if profile_by_stage:
            # Per-stage budgets and counters, consumed by
            # _profile_batch_predicate.
            self.profiler_target_prefill_ct = num_steps
            self.profiler_target_decode_ct = num_steps
            self.profiler_prefill_ct = 0
            self.profiler_decode_ct = 0
        elif start_step:
            self.profiler_target_forward_ct = (
                self.profiler_start_forward_ct + num_steps
            )
        else:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
        # _profile_batch_predicate stops the profile once forward_ct reaches
        # profiler_target_forward_ct.
        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start every profiler selected by ``self.profiler_activities``.

        Recognized activity names:
          * ``"CPU"`` / ``"GPU"``: a torch.profiler session with CPU / CUDA
            activities. Not started when ``"RPD"`` is also requested.
          * ``"RPD"``: an rpdTracerControl session. TP rank 0 first recreates
            ``trace.rpd`` with the RocpdSchema, then all TP ranks meet at a
            barrier before the tracer starts.
          * ``"MEM"``: CUDA memory-history recording (max_entries=100000).
          * ``"CUDA_PROFILER"``: ``cudaProfilerStart``.

        ``self.profile_in_progress`` is set to True by each of these that
        starts.

        Args:
            stage: optional forward mode, used only to label the log message.

        Returns:
            ``ProfileReqOutput(success=True, message="Succeeded")``.
        """
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        requested = self.profiler_activities
        torch_activity_by_name = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            torch_activity_by_name[name]
            for name in requested
            if name in torch_activity_by_name
        ]

        if "RPD" in requested:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()
            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                f"rpd-{time.time()}-TP-{self.tp_rank}.trace.json.gz",
            )

            if self.tp_rank == 0:
                # Rank 0 recreates the sqlite trace database from scratch.
                import sqlite3

                from rocpd.schema import RocpdSchema

                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            # Presumably keeps other ranks from tracing before rank 0 has
            # written the schema.
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            stack_flag = self.torch_profiler_with_stack
            shapes_flag = self.torch_profiler_record_shapes
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True if stack_flag is None else stack_flag,
                record_shapes=False if shapes_flag is None else shapes_flag,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in requested:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in requested:
            torch.cuda.cudart().cudaProfilerStart()
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2773

2774
2775
2776
2777
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop all running profilers and write their outputs.

        Outputs go to ``self.torch_profiler_output_dir``, which is created if
        missing:
          * torch.profiler: ``<profile_id>-TP-<rank>[-<stage>].trace.json.gz``,
            followed by a TP barrier.
          * RPD: the tracer is stopped and flushed. After a TP barrier, rank 0
            converts ``trace.rpd`` into the trace at ``self.rpd_profile_path``.
          * MEM: a ``<timestamp>-TP-<rank>-memory[-<stage>].pickle`` snapshot
            is written, then memory-history recording is disabled.
          * CUDA_PROFILER: ``cudaProfilerStop``.

        Afterwards the profiling state is reset, including both step-mode
        counters. Otherwise ``_profile_batch_predicate`` would keep calling
        this method on every later forward pass.

        Args:
            stage: optional forward mode appended to the log message and to
                the trace/snapshot file names.

        Returns:
            A failed ``ProfileReqOutput`` if no profile is running, otherwise
            a successful one.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            # All ranks must have flushed before rank 0 converts the trace.
            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # Was ``rpd_profiler_path`` (typo), which left the real path stale.
            self.rpd_profile_path = None

        # Guard against a missing activity list for both checks below, not
        # just the MEM one.
        activities = self.profiler_activities or []

        if "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False
        self.profiler_start_forward_ct = None
        # A stale target would make the step-mode predicate call stop_profile
        # again on every subsequent forward pass.
        self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
2857
2858
            elif batch.forward_mode.is_idle():
                pass
2859
            else:
2860
                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
2861
2862
2863
2864
2865
2866
2867
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2868
2869
2870
2871
2872
            if (
                self.profiler_start_forward_ct
                and self.profiler_start_forward_ct == self.forward_ct
            ):
                self.start_profile()
2873

2874
2875
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        if recv_req == ExpertDistributionReq.START_RECORD:
2876
            get_global_expert_distribution_recorder().start_record()
2877
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
2878
            get_global_expert_distribution_recorder().stop_record()
2879
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
2880
            get_global_expert_distribution_recorder().dump_record()
2881
2882
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
2883
        return ExpertDistributionReqOutput()
2884

2885
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a session keyed by ``recv_req.session_id``.

        A warning is logged and ``OpenSessionReqOutput(session_id, False)`` is
        returned if the id is already in use or is None. Otherwise a new
        ``Session`` is stored and ``OpenSessionReqOutput(session_id, True)``
        is returned.
        """
        sid = recv_req.session_id
        if sid in self.sessions:
            logger.warning(f"session id {sid} already exist, cannot open.")
            return OpenSessionReqOutput(sid, False)
        if sid is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(sid, False)

        self.sessions[sid] = Session(recv_req.capacity_of_str_len, sid)
        return OpenSessionReqOutput(sid, True)
2899
2900
2901
2902
2903
2904
2905
2906
2907

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``.

        An unknown id only logs a warning.
        """
        sid = recv_req.session_id
        if sid in self.sessions:
            del self.sessions[sid]
            return
        logger.warning(f"session id {sid} does not exist, cannot delete.")

2908
2909
    def get_print_prefix(self):
        """Return the log prefix for this scheduler, e.g. ``" DP0 TP1 PP2"``.

        Each part is included only when relevant: DP when ``attn_dp_rank`` is
        not None, TP when ``server_args.tp_size > 1``, and PP when
        ``pp_size > 1``.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2918
2919
2920
2921
2922
2923
2924
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2925

2926
2927
2928
2929
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health checks are recognized by a string ``rid`` that starts with
    ``"HEALTH_CHECK"``. Objects without an ``rid`` attribute, or whose
    ``rid`` is not a string (e.g. ``None``), are not health checks. They
    yield False instead of raising AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2930
2931
2932
2933
def is_work_request(recv_req):
    """Return True for tokenized generate or embedding requests."""
    work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
    return isinstance(recv_req, work_types)


2934
2935
2936
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
2947
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2948
2949
2950
2951
2952
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Steps:
      1. Configure the process: title, faulthandler,
         ``kill_itself_when_parent_died``, logging, and CPU affinity when
         ``SGLANG_SET_CPU_AFFINITY`` is set.
      2. Build a ``Scheduler`` and send a "ready" message with its token
         limits through ``pipe_writer``.
      3. Run the event loop matching the disaggregation mode, pipeline size
         and overlap setting.

    Any exception is logged with its traceback, and SIGQUIT is sent to the
    parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolve it before building the prefix, so router-assigned DP ranks also
    # appear in the process title and the log prefix.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    kill_itself_when_parent_died()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)