scheduler.py 110 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from pathlib import Path
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
39
40
41
42
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
43
44
45
46
47
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
48
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
49
50
51
52
53
54
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
55
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
56
    ReqToMetadataIdxAllocator,
57
    TransferBackend,
58
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
59
)
60
from sglang.srt.distributed import get_pp_group, get_world_group
fzyzcjy's avatar
fzyzcjy committed
61
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
xm:D's avatar
xm:D committed
62
63
64
65
66
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
67
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
68
69
70
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
71
    CloseSessionReqInput,
72
    ExpertDistributionReq,
73
    ExpertDistributionReqOutput,
74
75
    FlushCacheReqInput,
    FlushCacheReqOutput,
76
77
    GetInternalStateReq,
    GetInternalStateReqOutput,
78
79
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
80
    HealthCheckOutput,
81
82
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
83
84
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
85
86
    OpenSessionReqInput,
    OpenSessionReqOutput,
87
    ProfileReq,
88
89
    ProfileReqOutput,
    ProfileReqType,
90
91
92
93
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
94
95
    RpcReqInput,
    RpcReqOutput,
96
97
    SetInternalStateReq,
    SetInternalStateReqOutput,
98
99
    SlowDownReqInput,
    SlowDownReqOutput,
100
101
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
102
103
    UnloadLoRAAdapterReqInput,
    UnloadLoRAAdapterReqOutput,
Chayenne's avatar
Chayenne committed
104
105
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
106
107
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
108
109
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
110
)
111
from sglang.srt.managers.mm_utils import init_embedding_cache
112
113
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
114
    MultimodalInputs,
115
116
    Req,
    ScheduleBatch,
117
    global_server_args_dict,
118
)
119
120
121
122
123
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
124
125
126
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
127
from sglang.srt.managers.session_controller import Session
128
from sglang.srt.managers.tp_worker import TpModelWorker
129
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
130
from sglang.srt.managers.utils import validate_input_length
tarinkk's avatar
tarinkk committed
131
132
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
133
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
134
from sglang.srt.mem_cache.radix_cache import RadixCache
135
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
136
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
137
from sglang.srt.reasoning_parser import ReasoningParser
138
from sglang.srt.server_args import PortArgs, ServerArgs
139
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
140
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
141
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
142
from sglang.srt.utils import (
143
    DeepEPMode,
144
    DynamicGradMode,
145
    broadcast_pyobj,
fzyzcjy's avatar
fzyzcjy committed
146
    configure_gc_logger,
147
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
148
    disable_request_logging,
149
    get_available_gpu_memory,
150
    get_bool_env_var,
151
    get_zmq_socket,
152
    is_cpu,
Lianmin Zheng's avatar
Lianmin Zheng committed
153
    kill_itself_when_parent_died,
154
    point_to_point_pyobj,
155
    pyspy_dump_schedulers,
156
157
    require_mlp_sync,
    require_mlp_tp_gather,
158
    set_gpu_proc_affinity,
159
160
161
    set_random_seed,
    suppress_other_loggers,
)
162
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
163
164
165

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# NOTE(review): presumably gates recording of per-step forward times; confirm at use sites.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar-related timeout, overridable via SGLANG_GRAMMAR_TIMEOUT (default 300).
# NOTE(review): unit assumed to be seconds — confirm at use sites.
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))

# Evaluated once at import time.
_is_cpu = is_cpu()

173

174
175
@dataclass
class GenerationBatchResult:
    """Result of running one generation batch through the model worker.

    Field order is part of the positional constructor; do not reorder.
    """

    # Logits output from the model; None when not available.
    logits_output: Optional[LogitsProcessorOutput]
    # Proxy tensors carrying hidden states between pipeline stages, if any.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids; None when not available.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request start offsets for logprob computation in the extend segment.
    extend_logprob_start_len_per_req: List[int]
    # Batch id.
    bid: int
    # Whether this batch could be run with a CUDA graph.
    can_run_cuda_graph: bool
183
184
185
186
187
188
189
190


@dataclass()
class EmbeddingBatchResult:
    """Pairs the embeddings computed for a batch with that batch's id."""

    # Embedding output for the batch.
    embeddings: torch.Tensor
    # Batch identifier (presumably matches ScheduleBatch's bid).
    bid: int


191
192
193
194
195
196
197
198
199
200
201
202
class KvMetrics:
    """Mutable bag of KV-cache and request-slot metrics; all fields start unset (None)."""

    def __init__(self):
        # Request-slot counts.
        self.request_active_slots = self.request_total_slots = None
        # KV-cache block counts.
        self.kv_active_blocks = self.kv_total_blocks = None
        # Queue and cache indicators.
        self.num_requests_waiting = None
        self.gpu_cache_usage_perc = None
        self.gpu_prefix_cache_hit_rate = None
        # Data-parallel rank the metrics refer to.
        self.data_parallel_rank = None


203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class IdleSleeper:
    """
    Lets an otherwise busy-looping scheduler block while there is no work.

    Deployments can sit idle for long stretches. Spinning at 100% CPU on one
    thread per GPU wastes power and eats the CPU thermal headroom that is
    needed once requests do arrive, which matters most on multi-GPU hosts.

    The approach is simply to zmq-poll every socket that can deliver work
    requiring immediate handling, so any incoming message wakes us up.
    """

    def __init__(self, sockets):
        self.poller = zmq.Poller()
        register = self.poller.register
        # Only readability matters: any pending message should end the wait.
        for sock in sockets:
            register(sock, zmq.POLLIN)

    def maybe_sleep(self):
        """Wait until a registered socket is readable or the timeout expires."""
        timeout_ms = 1000
        self.poller.poll(timeout_ms)


Byron Hsu's avatar
Byron Hsu committed
224
225
226
227
228
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
229
230
231
232
233
234
235
236
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler for one GPU worker.

        Initialization order matters:
        * the model config/tokenizer are loaded before the TP worker;
        * the TP worker must exist before the memory pool and cache, which are
          obtained from it;
        * the tree cache must exist before the schedule policy and the
          disaggregation queues, which both receive it.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: GPU index this scheduler drives.
            tp_rank: Tensor-parallel rank.
            pp_rank: Pipeline-parallel rank.
            dp_rank: Data-parallel rank, or None.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.dp_rank = dp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            # Only the first attention-TP rank of the first pipeline stage
            # owns the sockets to the tokenizer/detokenizer/RPC endpoints.
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )
            self.send_metrics_from_scheduler = get_zmq_socket(
                context, zmq.PUSH, port_args.metrics_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            # Other ranks own no sockets; outgoing sends become no-ops.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_metrics_from_scheduler = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.last_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.num_retracted_reqs: int = 0
        self.num_paused_reqs: int = 0
        self.kv_transfer_speed_gb_s: float = 0.0
        self.kv_transfer_latency_ms: float = 0.0
        self.sessions: Dict[str, Session] = {}
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # Any non-positive value (e.g. -1) disables chunked prefill; an unset
        # (None) value is treated the same instead of raising on `None <= 0`.
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Linear decay from init_new_token_ratio down to min_new_token_ratio
        # over default_new_token_ratio_decay_steps steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        # Resolve the parent process before starting the watchdog thread so the
        # attribute already exists whenever the thread runs.
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver, profiler and metric stats
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )
        self.init_profier()
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
            ]
        )

        # Init disaggregation
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

        if get_bool_env_var("SGLANG_GC_LOG"):
            configure_gc_logger()

529
530
531
532
    def maybe_sleep_on_idle(self):
        if self.idle_sleeper is not None:
            self.idle_sleeper.maybe_sleep()

533
534
    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``. ``self.processor`` is assigned in every branch
        (None for text-only models and when ``skip_tokenizer_init`` is set), so
        later code can read it without ``hasattr``.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            # Reuse the processor's own tokenizer.
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only models have no processor; previously this attribute was
            # left undefined in this branch.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the tree cache.

        Cache selection:
        * ``ChunkCache`` (``SWAChunkCache`` for hybrid models) when
          ``chunked_prefill_size`` is not None and the radix cache is disabled;
        * ``HiRadixCache`` when hierarchical caching is enabled; its layer-done
          counter is registered with the TP worker;
        * ``RadixCache`` otherwise, constructed with
          ``disable=server_args.disable_radix_cache``.

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise ``speculative_num_draft_tokens +
        speculative_eagle_topk * speculative_num_steps``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            if self.model_config.is_hybrid:
                ChunkCacheClass = SWAChunkCache
            else:
                ChunkCacheClass = ChunkCache
            self.tree_cache = ChunkCacheClass(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                # With DP attention the cache is coordinated within the
                # attention-TP group rather than the full TP group.
                tp_cache_group=(
                    self.attn_tp_cpu_group
                    if server_args.enable_dp_attention
                    else self.tp_cpu_group
                ),
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
            )
            self.tp_worker.register_hicache_layer_transfer_counter(
                self.tree_cache.cache_controller.layer_done_counter
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
618

619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
    def init_profier(self):
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

634
635
    def init_metrics(self):
        """Reset throughput, step-time and speculative-decoding counters.

        Always (re)creates ``self.stats``; creates ``self.metrics_collector``
        only when ``self.enable_metrics`` is true.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    "engine_type": "unified",
                },
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
651

652
653
    def init_kv_events(self, kv_events_config: Optional[str]):
        if self.enable_kv_cache_events:
654
655
656
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )
657

Byron Hsu's avatar
Byron Hsu committed
658
    def init_disaggregation(self):
659
660
661
662
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

Byron Hsu's avatar
Byron Hsu committed
663
664
665
666
        if (
            self.disaggregation_mode == DisaggregationMode.DECODE
        ):  # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
Byron Hsu's avatar
Byron Hsu committed
667
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
Byron Hsu's avatar
Byron Hsu committed
668
669
                buffer_size
            )
670
671
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
672
673
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
674
675
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
Byron Hsu's avatar
Byron Hsu committed
676
677
678

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
679
                gloo_group=self.attn_tp_cpu_group,
Byron Hsu's avatar
Byron Hsu committed
680
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
681
                tp_rank=self.tp_rank,
682
                metadata_buffers=self.disagg_metadata_buffers,
Byron Hsu's avatar
Byron Hsu committed
683
684
                scheduler=self,
                tree_cache=self.tree_cache,
Byron Hsu's avatar
Byron Hsu committed
685
686
687
688
689
690
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
Byron Hsu's avatar
Byron Hsu committed
691
692
693
694
695
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
Byron Hsu's avatar
Byron Hsu committed
696
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
697
                metadata_buffers=self.disagg_metadata_buffers,
Byron Hsu's avatar
Byron Hsu committed
698
699
700
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
701
                gloo_group=self.attn_tp_cpu_group,
Byron Hsu's avatar
Byron Hsu committed
702
703
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
704
705
                dp_size=self.server_args.dp_size,
                gpu_id=self.gpu_id,
Byron Hsu's avatar
Byron Hsu committed
706
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
707
708
                max_total_num_tokens=self.max_total_num_tokens,
                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
709
                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
710
                transfer_backend=self.transfer_backend,
Byron Hsu's avatar
Byron Hsu committed
711
            )
Liangsheng Yin's avatar
Liangsheng Yin committed
712

Byron Hsu's avatar
Byron Hsu committed
713
714
715
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
Byron Hsu's avatar
Byron Hsu committed
716
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
Byron Hsu's avatar
Byron Hsu committed
717
718
                buffer_size
            )
719
720
            self.disagg_metadata_buffers = MetadataBuffers(
                buffer_size,
721
722
                hidden_size=self.model_config.hf_text_config.hidden_size,
                dtype=self.model_config.dtype,
723
724
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )
Byron Hsu's avatar
Byron Hsu committed
725

Liangsheng Yin's avatar
Liangsheng Yin committed
726
            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
Byron Hsu's avatar
Byron Hsu committed
727
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
Byron Hsu's avatar
Byron Hsu committed
728
729
730
731
732
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
Byron Hsu's avatar
Byron Hsu committed
733
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
734
                metadata_buffers=self.disagg_metadata_buffers,
Byron Hsu's avatar
Byron Hsu committed
735
736
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
Byron Hsu's avatar
Byron Hsu committed
737
                gpu_id=self.gpu_id,
Byron Hsu's avatar
Byron Hsu committed
738
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
739
                gloo_group=self.attn_tp_cpu_group,
Byron Hsu's avatar
Byron Hsu committed
740
741
742
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
743
                scheduler=self,
Byron Hsu's avatar
Byron Hsu committed
744
745
746
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
Byron Hsu's avatar
Byron Hsu committed
747
748
            )
            # The prefill requests that are in the middle of kv sending
749
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
750

751
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the scheduler loop without CPU/GPU overlap.

        Every iteration pulls newly arrived requests, selects the next batch
        and either executes it synchronously (then processes its result) or,
        when there is nothing to run, performs idle-time housekeeping.
        The loop never returns.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # Idle: run the memory self-check, reset the new-token ratio
                # to its initial value and possibly sleep.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()
            else:
                batch_result = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_result)

            self.last_batch = next_batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop with CPU post-processing overlapped on GPU work.

        The current batch is launched before the result of the previous batch
        is processed, so the CPU-side handling of step N-1 happens after step N
        has been submitted. Launched batches and their results are buffered in
        ``self.result_queue`` until the following iteration pops them.
        The loop never returns.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Nothing was launched before this batch: push a
                    # DUMMY_FIRST batch through result processing to prime the
                    # overlap pipeline (it triggers the sampling_info_done event).
                    primer_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(primer_batch, None, batch.launch_done)

            if self.last_batch:
                # Post-process the previously launched batch.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # Hand over the *current* batch's launch_done event, not the
                # previous batch's own one.
                current_launch_done = batch.launch_done if batch else None
                self.process_batch_result(prev_batch, prev_result, current_launch_done)
            elif batch is None:
                # Idle: run the memory self-check, reset the new-token ratio
                # to its initial value and possibly sleep.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Work is split into ``self.pp_size`` microbatches that are visited
        round-robin. For microbatch ``mb_id`` each inner iteration:

        1. receives new requests, then schedules and runs the next batch for
           this microbatch slot;
        2. on the last PP rank, sends the sampled token ids (plus the logprob
           inputs and logits-output fields when ``return_logprob`` is set)
           over ``self.pp_group``;
        3. if slot ``(mb_id + 1) % pp_size`` holds a scheduled batch, receives
           its outputs from ``self.pp_group`` and runs ``process_batch_result``
           on it;
        4. on non-last ranks, forwards the outputs received in the previous
           inner iteration, the received requests, and this batch's
           hidden-state proxy tensors to the next stage.

        When no microbatch ran anything in a full round, idle housekeeping
        runs. The loop never returns.
        """
        # Per-slot state: scheduled batch, last batch whose result was
        # processed, running batch, and the batch id reported by run_batch.
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Outputs to forward to the next stage (set on the last rank, or
        # carried over from the previous inner iteration's receive).
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap this slot's state into the scheduler before scheduling.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        if self.cur_batch.return_logprob:
                            # Flatten logits_output fields under a
                            # "logits_output." prefix so they can travel as a
                            # flat tensor dict; the receiver strips the prefix.
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Rebuild LogitsProcessorOutput from the prefixed keys, if any were sent.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None
                    output_result = GenerationBatchResult(
                        logits_output=logits_output,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=next_pp_outputs.tensors.get(
                            "extend_input_len_per_req", None
                        ),
                        extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
                            "extend_logprob_start_len_per_req", None
                        ),
                        bid=bids[next_mb_id],
                        # NOTE(review): `result` is only assigned when this
                        # slot's cur_batch ran, so this is the current (or an
                        # earlier) microbatch's flag, not mbs[next_mb_id]'s —
                        # confirm this is intended.
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        # src = this rank, dst = same position in the next PP stage.
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # What was received for the next slot is forwarded downstream
                # during the next inner iteration.
                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

    def recv_requests(self) -> List[Req]:
        """Receive new requests on the entry rank and share them with peer ranks.

        Only attention-TP rank 0 actually receives. On the first pipeline stage
        it drains the tokenizer socket and then the RPC socket without
        blocking. On later stages it receives the list forwarded by the
        previous stage via ``point_to_point_pyobj``.

        The list is then broadcast. With DP attention, generate/embedding
        requests are broadcast within the attention-TP group and every other
        (control) request within the full TP group. Otherwise the whole list
        is broadcast within the TP group when ``tp_size != 1``.

        Returns:
            The received requests. With DP attention, work requests come
            first, followed by control requests. A rank that neither received
            nor took part in a broadcast returns ``None``.
        """

        def drain(sock) -> list:
            # Collect every message currently queued on ``sock`` without waiting.
            drained = []
            while True:
                try:
                    drained.append(sock.recv_pyobj(zmq.NOBLOCK))
                except zmq.ZMQError:
                    return drained

        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = drain(self.recv_from_tokenizer)
                recv_reqs.extend(drain(self.recv_from_rpc))
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.device_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(
                    recv_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return recv_reqs

        work_reqs = None
        control_reqs = None
        if self.attn_tp_rank == 0:
            work_reqs = []
            control_reqs = []
            for req in recv_reqs:
                if isinstance(
                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                ):
                    work_reqs.append(req)
                else:
                    control_reqs.append(req)

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any immediate reply.

        A health-check generate request that arrives while work is in progress
        (a chunked request exists or the running batch is non-empty) is not
        dispatched; only ``return_health_check_ct`` is incremented.
        Replies of type ``RpcReqOutput`` go back over the RPC socket (when one
        exists). All other non-None replies go to the tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming) and (
                self.chunked_req is not None or not self.running_batch.is_empty()
            ):
                self.return_health_check_ct += 1
                continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        A request is marked aborted (``set_finish_with_abort``) and still
        enqueued through ``_add_request_to_queue`` in these cases:
        - it refers to a session id that does not exist;
        - its multimodal expansion is too long;
        - it fails input-length validation;
        - its ``logprob_start_len`` is out of range;
        - it hits a cached invalid grammar.

        In disaggregated mode, a request without a bootstrap room is aborted
        via ``prepare_abort`` and streamed out immediately instead.

        Requests whose structured-output grammar is not cached yet are put on
        ``self.grammar_queue`` rather than the regular queue.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but is unknown.
                req.set_finish_with_abort(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens so prompt + output stays below max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        sampling_params = req.sampling_params
        # Pick the first constraint that is set, in priority order. Every kind
        # is tested with ``is not None`` so that the selection agrees with the
        # "any constraint present" check; previously a set-but-falsy
        # structural_tag passed that check but matched no branch, leaving
        # ``key`` unbound.
        key = next(
            (
                (kind, spec)
                for kind, spec in (
                    ("json", sampling_params.json_schema),
                    ("regex", sampling_params.regex),
                    ("ebnf", sampling_params.ebnf),
                    ("structural_tag", sampling_params.structural_tag),
                )
                if spec is not None
            ),
            None,
        )
        if key is not None:
            assert self.grammar_backend is not None
            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Not compiled yet: wait in the grammar queue.
                req.grammar_key = key
                add_to_grammar_queue = True
            elif value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req`` with its enqueue time and route it by disaggregation mode.

        PREFILL servers put it on the bootstrap queue, DECODE servers on the
        pre-allocation queue, and all other modes on ``self.waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests to the queue that matches the disaggregation mode.

        Unlike ``_add_request_to_queue`` this does not stamp
        ``queue_time_start``. ``is_retracted`` is only forwarded to the decode
        pre-allocation queue; the other modes ignore it.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return
        if mode == DisaggregationMode.DECODE:
            # Decode servers park requests in the pending pre-allocation queue.
            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
            return
        self.waiting_queue.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Requests that are too long, either after multimodal expansion or per
        ``validate_input_length``, are marked aborted with
        ``set_finish_with_abort`` and still enqueued, matching the handling in
        ``handle_generate_request``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
            token_type_ids=recv_req.token_type_ids,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Previously the request was enqueued without being aborted. Abort
            # it first, as the multimodal branch above and
            # handle_generate_request both do.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def _emit_kv_metrics(self):
        """Send a ``KvMetrics`` snapshot derived from ``self.stats``.

        The snapshot goes out on ``send_metrics_from_scheduler``. Nothing is
        sent once that socket has been closed.
        """
        stats = self.stats
        capacity = self.max_total_num_tokens

        kv_metrics = KvMetrics()
        kv_metrics.request_active_slots = stats.num_running_reqs
        kv_metrics.request_total_slots = self.max_running_requests
        # token_usage is stored as a fraction of capacity; convert it back to
        # a token count.
        kv_metrics.kv_active_blocks = int(stats.token_usage * capacity)
        kv_metrics.kv_total_blocks = capacity
        kv_metrics.num_requests_waiting = stats.num_queue_reqs
        kv_metrics.gpu_cache_usage_perc = stats.token_usage
        kv_metrics.gpu_prefix_cache_hit_rate = stats.cache_hit_rate
        kv_metrics.data_parallel_rank = 0 if self.dp_rank is None else self.dp_rank

        metrics_socket = self.send_metrics_from_scheduler
        if metrics_socket.closed:
            return
        metrics_socket.send_pyobj(kv_metrics)

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a new prefill batch and refresh metrics.

        Args:
            adder: The ``PrefillAdder`` that built the batch. Its
                ``log_input_tokens`` / ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of running requests. It is reported as
                ``#running-req`` outside PREFILL disaggregation mode and
                stored in ``self.stats.num_running_reqs``.

        When ``self.enable_metrics`` is set, this also updates ``self.stats``,
        passes it to the metrics collector and emits KV metrics. KV events are
        published unconditionally.
        """
        # A single clock read keeps the measured gap and the new tic consistent.
        now = time.perf_counter()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now

        # Throughput of the *previous* prefill batch over the elapsed gap.
        self.last_input_throughput = self.last_prefill_tokens / gap_latency
        self.last_prefill_tokens = adder.log_input_tokens

        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
            self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"{usage_msg}"
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
        else:
            f += f"#running-req: {running_bs}, "
            f += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(f)

        if self.enable_metrics:
            # Both ratios below are guarded: with no tokens or no admitted
            # requests they would otherwise raise ZeroDivisionError.
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = 0
            for req in can_run_list:
                total_queue_latency += req.queue_time_end - req.queue_time_start
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log periodic decode statistics and update metrics.

        Args:
            can_run_cuda_graph: Whether the last decode step could use a CUDA graph;
                it is echoed in the log line.
            running_batch: Batch to report on; falls back to ``self.running_batch``.
        """
        batch = running_batch or self.running_batch

        # Single clock read so no time is lost between measuring and resetting.
        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now

        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        usage_msg, num_used = self.token_to_kv_pool_allocator.log_usage(
            self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = f"Decode batch. " f"#running-req: {num_running_reqs}, " f"{usage_msg}"

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Average accept length over what run_batch accumulated since the
            # last log; a window without speculative forwards reports 0.
            if self.spec_num_total_forward_ct > 0:
                spec_accept_length = (
                    self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                )
            else:
                spec_accept_length = 0
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
            self._emit_kv_metrics()
        self._publish_kv_events()
1397

Lianmin Zheng's avatar
Lianmin Zheng committed
1398
    def check_memory(self):
        """Verify that KV-cache tokens and request slots are fully accounted for.

        Also emits idle-time metrics at most every 30 seconds and publishes KV
        events.

        Raises:
            ValueError: If the token-to-KV pool or the req-to-token pool does not
                account for its whole capacity (a leak).
        """
        if isinstance(self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
            available_token_size = self.token_to_kv_pool_allocator.full_available_size()
        else:
            available_token_size = self.token_to_kv_pool_allocator.available_size()
        available_size = available_token_size + self.tree_cache.evictable_size()
        protected_size = self.tree_cache.protected_size()
        # With hierarchical cache, protected tokens are legitimately held and are
        # excluded from the expected total.
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{available_token_size=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        # In DECODE disaggregation mode the pool also owns pre-allocated slots.
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            req_total_size = (
                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
            )
        else:
            req_total_size = self.req_to_token_pool.size

        num_free_req_slots = len(self.req_to_token_pool.free_slots)
        if num_free_req_slots != req_total_size:
            # Report the same total the check compared against (it includes
            # pre_alloc_size in DECODE mode).
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={num_free_req_slots}, "
                f"total_size={req_total_size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1453

1454
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the next batch: a new prefill batch if one can be formed,
        otherwise the updated running decode batch.

        Before choosing, the previous extend batch (if any) is filtered and merged
        into ``self.running_batch``. The result may be ``None``, or an idle batch
        substituted by MLP-sync preparation.
        """
        excluded_chunked_reqs = set()
        if self.chunked_req:
            # Pull the in-flight chunked request out so that only requests whose
            # prefill has finished are merged into the running batch.
            excluded_chunked_reqs.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # The chunked request keeps its rid but will get a new req_pool_idx.
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism a microbatch can still reference a
                # chunked request that already finished; exclude it as well.
                excluded_chunked_reqs.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded_chunked_reqs))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Fold whatever survived filtering into the running batch.
            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        new_batch = self.get_new_batch_prefill()

        needs_mlp_sync = require_mlp_sync(self.server_args)
        if needs_mlp_sync and not self.spec_algorithm.is_none():
            # Speculative decoding cannot mix prefill and decode batches in one DP
            # attention group, so sync the prefill candidate first; decode is only
            # prepared later if no prefill batch came out of this step.
            new_batch = self.prepare_mlp_sync_batch(new_batch)
            needs_mlp_sync = new_batch is None

        if new_batch is not None:
            # Prefill has priority over decode.
            ret = new_batch
        elif self.running_batch.is_empty():
            ret = None
        else:
            self.running_batch = self.update_running_batch(self.running_batch)
            ret = None if self.running_batch.is_empty() else self.running_batch

        if needs_mlp_sync:
            ret = self.prepare_mlp_sync_batch(ret)
        return ret
1512

1513
1514
1515
1516
1517
1518
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted alongside ``running_bs``.

        The count is capped by ``max_micro_batch_size``. Under pipeline
        parallelism it is also capped by the req-to-token pool's free slots.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1519
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue.

        Any in-flight chunked request is continued first. Waiting requests, in
        priority order, are then admitted through a ``PrefillAdder`` until a LoRA,
        batch-size, request-slot or token budget is hit. Admitted requests are
        removed from ``self.waiting_queue``.

        Returns:
            The prepared extend ``ScheduleBatch``, possibly mixed with the running
            decode batch, or ``None`` if no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            self.tree_cache.check_hicache_events()

        # Get priority queue
        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop if admitting this request would exceed the per-batch LoRA limit.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc and transfer queues also hold
                # req_to_token slots, so also check the pool's actual free slots.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(self.tree_cache)
            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Rebuilding it for every waiting request
        # made this filter O(len(waiting_queue) * len(can_run_list)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        if self.enable_hierarchical_cache:
            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
            new_batch.hicache_consumer_index = (
                self.tree_cache.ready_to_load_host_cache()
            )

        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1669
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step.

        Finished requests are dropped. If decode memory is insufficient (or
        TEST_RETRACT forces it), requests are retracted back to the waiting queue
        and ``new_token_ratio`` is updated; otherwise the ratio decays toward
        ``min_new_token_ratio``. The batch is mutated in place and returned.
        """
        size_before_filter = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        should_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if should_retract:
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
        else:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        # A smaller batch frees capacity for new requests.
        if batch.batch_size() < size_before_filter:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1705

1706
1707
1708
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap its outputs in a result.

        Increments ``forward_ct``, runs the profiler predicate, and optionally
        sleeps for ``forward_sleep_time``. Generation models go through the TP
        worker, or the draft worker when speculative decoding is enabled;
        embedding/reward models go through ``forward_batch_embedding``.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=model_worker_batch.bid)

        is_last_rank = self.pp_group.is_last_rank

        if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()
            # update the consumer index of hicache to the running batch
            self.tp_worker.set_hicache_consumer(
                model_worker_batch.hicache_consumer_index
            )
            forward_out = self.tp_worker.forward_batch_generation(model_worker_batch)
            if is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                # Non-last pipeline ranks produce hidden-state proxies, not tokens.
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = model_worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            step_bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + step_bs
            self.spec_num_total_forward_ct += step_bs
            self.num_generated_tokens += num_accepted_tokens

        if is_last_rank:
            batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        wants_input_lens = batch.return_logprob or self.spec_algorithm.is_eagle()
        extend_input_len_per_req = (
            [r.extend_input_len for r in batch.reqs] if wants_input_lens else None
        )
        extend_logprob_start_len_per_req = (
            [r.extend_logprob_start_len for r in batch.reqs]
            if batch.return_logprob
            else None
        )

        return GenerationBatchResult(
            logits_output=logits_output if is_last_rank else None,
            pp_hidden_states_proxy_tensors=(
                pp_hidden_states_proxy_tensors if not is_last_rank else None
            ),
            next_token_ids=next_token_ids if is_last_rank else None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1786

1787
1788
1789
1790
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Hand a finished batch's result to the handler for its forward mode.

        Afterwards, if health-check signals are pending, one is consumed and a
        ``HealthCheckOutput`` is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            # Idle batches only need bookkeeping under the overlap scheduler.
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Return some signal for the health check.
        # This is used to prevent the health check signal being blocked by long context prefill.
        # However, one minor issue is that this code path does not check the status of detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1811
1812
    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
        """Sync ``local_batch`` across DP ranks using this scheduler's configuration.

        Thin wrapper that forwards scheduler state and server arguments to
        ``prepare_mlp_sync_batch_raw``.
        """
        args = self.server_args
        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
            require_mlp_tp_gather=require_mlp_tp_gather(args),
        )

    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
        require_mlp_tp_gather: bool,
    ):
        """All-gather per-rank batch info over ``tp_cpu_group`` and annotate the batch.

        Each rank contributes: its token count, whether it can use a CUDA graph,
        its logprob token count, whether it is an extend batch, and the fields
        returned by ``TboDPAttentionPreparer.prepare_all_gather``. If this rank
        has no batch but another rank has tokens, an idle batch from
        ``get_idle_batch`` is created so this rank participates.

        Returns:
            ``local_batch`` (or the created idle batch) annotated with the global
            token counts, extend flag, TBO split info and CUDA-graph eligibility,
            or ``None`` if no rank has any tokens.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Fields 0-3 are fixed; everything after them comes from the TBO preparer.
        num_fixed_fields = 4
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        # Size the gather buffer from the actual local payload instead of a
        # hard-coded width, so it cannot drift from prepare_all_gather's output.
        global_info = torch.empty(
            (dp_size, attn_tp_size, local_info.numel()),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Only the first attention-TP rank of each DP group is read here.
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, num_fixed_fields:]
        )

        # Some rank has work, so this rank must join with an idle batch.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch
1927
1928
1929
1930
1931

    def get_idle_batch(self):
        """Build a request-less ScheduleBatch prepared for an idle forward pass.

        The batch shares this scheduler's memory pools, tree cache, model
        config, overlap flag, speculative algorithm and custom-logit-processor
        setting. It is used e.g. when this rank has no local batch while other
        ranks still have tokens to process.
        """
        init_args = (
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch = ScheduleBatch.init_new(*init_args)
        batch.prepare_for_idle()
        return batch

1942
1943
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each queued request holds a grammar future in ``req.grammar``. The queue
        is polled from the head, waiting at most 0.03 s per request, and polling
        stops at the first request that is not ready yet. A request whose wait
        counter exceeds ``GRAMMAR_TIMEOUT / 0.03`` polls is counted as timed out.

        With tensor parallelism (or attention-TP under DP attention), the ready
        and timed-out counts are max-reduced across ranks so that every rank
        moves the same prefix of the queue. Ranks that are behind block on the
        remaining futures.

        Requests with an invalid grammar, or that timed out, are finished with
        an abort but are still moved to the waiting queue.
        """
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=0.03)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    # We break right below, so at most one request times out per call.
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Block on requests that some other rank already saw as ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # The timed-out request sits right after the synced ready prefix. Start at
        # num_ready_reqs_max, not the local count. Indices below it were resolved
        # above and must not be cancelled or aborted, and the request at the
        # local index may already hold a grammar object rather than a future.
        # Clamp so we never index past the end of the queue.
        timeout_end = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for i in range(num_ready_reqs_max, timeout_end):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = max(num_ready_reqs_max, timeout_end)

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

2007
2008
2009
2010
2011
2012
2013
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Mark the next batch's sampling info as ready.

        If the batch carries grammars, the regex vocab mask is refreshed and the
        current stream is synchronized before the ``sampling_info_done`` event
        is set. Does nothing when ``next_batch_sampling_info`` is falsy.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

2014
2015
2016
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        The loop polls every ``watchdog_timeout / 2`` seconds. While a batch is
        set (``cur_batch`` is not None), it checks whether ``forward_ct`` has
        advanced since the last poll. If it has not advanced for more than
        ``watchdog_timeout`` seconds, the loop exits and the shutdown sequence
        runs:

        1. Log diagnostics, unless request logging is disabled.
        2. Call ``pyspy_dump_schedulers()``.
        3. Wait 5 s.
        4. Send SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            # Snapshot cur_batch once. Other code (e.g. flush_cache) may reset it
            # to None concurrently, and the diagnostics below must not touch None.
            batch = self.cur_batch
            if batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: `//` would floor any timeout below 2 s to 0 and turn
            # this into a busy loop.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            # The first two labels keep the exact text the `{self.cur_batch...=}`
            # form produced (`=` renders repr()), but read from the snapshot.
            logger.error(
                f"self.cur_batch.batch_size()={batch.batch_size()!r}, "
                f"self.cur_batch.reqs={batch.reqs!r}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

2048
2049
2050
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Run ``flush_cache`` and report its outcome as a FlushCacheReqOutput."""
        return FlushCacheReqOutput(success=self.flush_cache())
2051

2052
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when the scheduler is idle:

        - the waiting queue is empty,
        - the running batch is empty, and
        - with pipeline parallelism (``pp_size > 1``), every micro-batch in
          ``running_mbs`` is empty.

        When idle, the following are reset and ``torch.cuda.empty_cache()`` is
        called:

        - the current/last batch references and the tree cache,
        - the grammar backend, if any,
        - the request/token pools, plus the draft worker's pools when a
          speculative algorithm is active,
        - the generation and speculative-decoding counters.

        Returns:
            True if the cache was flushed, False if requests were still pending.
        """
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger (as the success branch does), not the root logger.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

Liangsheng Yin's avatar
Liangsheng Yin committed
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
    def get_load(self):
        """Estimate this scheduler's load, measured in tokens.

        The estimate is the sum of:

        - the KV-cache tokens neither available nor evictable (total capacity
          minus free and evictable tokens), and
        - the prompt lengths of queued requests. These are the waiting queue,
          plus the bootstrap queue in PREFILL disaggregation mode or the
          prealloc queue in DECODE disaggregation mode.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        occupied_tokens = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )
        queued_tokens = sum(len(req.origin_input_ids) for req in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            queued_tokens += sum(
                len(req.origin_input_ids)
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
        elif mode == DisaggregationMode.DECODE:
            queued_tokens += sum(
                len(entry.req.origin_input_ids)
                for entry in self.disagg_decode_prealloc_queue.queue
            )

        return occupied_tokens + queued_tokens

2110
2111
2112
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return global server args together with runtime statistics.

        Adds the last generation throughput, memory usage, the average
        speculative acceptance length (when available), step times (when
        ``RECORD_STEP_TIME`` is set) and the current load.

        Memory usage covers weights, KV cache and token capacity, plus the
        CUDA-graph footprint on non-CPU devices.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        model_runner = self.tp_worker.worker.model_runner
        kvcache = self.token_to_kv_pool_allocator.get_kvcache()
        memory_usage = {
            "weight": round(model_runner.weight_load_mem_usage, 2),
            "kvcache": round(kvcache.mem_usage, 2),
            "token_capacity": int(self.max_total_num_tokens),
        }
        if not _is_cpu:
            memory_usage["cuda_graph"] = round(model_runner.cuda_graph_mem_usage, 2)
        state["memory_usage"] = memory_usage

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
2138
2139
2140
2141
2142

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of global server args at runtime.

        Only keys in ``args_allow_update`` may be changed. ``max_micro_batch_size``
        must additionally lie in ``[1, max_running_requests // pp_size]``.
        Validation stops at the first rejected key, and then nothing is applied.

        On success, the current average speculative acceptance length is logged
        (if available), the acceptance counters are reset, and the new values
        are written into ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput with ``updated`` reporting whether the
            update was actually applied, and the current global server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        # Report the real outcome: a rejected update must not claim success.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke the scheduler method named ``recv_req.method``.

        The method is looked up on ``self`` and called with
        ``recv_req.parameters``. Exceptions are logged and returned as the reply
        message instead of propagating. ``barrier()`` is entered before replying
        whatever the outcome.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(error is None, "" if not error else str(error))

    def save_remote_model(self, params):
        """Ask the model runner to save the model to ``params["url"]``."""
        remote_url = params["url"]
        runner = self.tp_worker.worker.model_runner
        runner.save_remote_model(remote_url)

    def save_sharded_model(self, params):
        """Ask the model runner to save sharded weights.

        Only the ``path``, ``pattern`` and ``max_size`` entries of ``params``
        are forwarded, as keyword arguments.
        """
        runner = self.tp_worker.worker.model_runner
        runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

2212
2213
    def abort_request(self, recv_req: AbortReq):
        """Abort every request matched by ``recv_req``.

        A request matches when ``recv_req.abort_all`` is set or its rid starts
        with ``recv_req.rid``. How it is aborted depends on where it currently
        lives:

        1. Waiting queue: popped directly, and an ``AbortReq`` is sent back to
           the TokenizerManager so it can clean up its state.
        2. Grammar queue: the grammar (if any) is cancelled and the request is
           finished with an abort; it still runs one prefill forward pass.
        3. Running batch (plus ``cur_batch`` when it is a different batch):
           unfinished requests get ``to_abort = True``; they still run one
           decode forward pass so the existing cleanup path frees their KV cache.
        """

        def matches(req):
            return recv_req.abort_all or req.rid.startswith(recv_req.rid)

        # Abort method 1: directly pop from the queue.
        # This only works for requests that have not started anything.
        # Pop from the back so earlier indices stay valid.
        hit_indices = [i for i, req in enumerate(self.waiting_queue) if matches(req)]
        for i in reversed(hit_indices):
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Abort method 2: call `set_finish_with_abort`.
        # The original code notes that the input_ids are reduced to one token
        # elsewhere, so that this extra prefill is cheap.
        for req in self.grammar_queue:
            if not matches(req):
                continue
            logger.debug(f"Abort grammar queue request. {req.rid=}")
            if req.grammar:
                req.grammar.cancel()
            req.set_finish_with_abort("Aborted by AbortReq.")

        # Abort method 3: set `to_abort=True` on running requests.
        if self.cur_batch is None or self.cur_batch is self.running_batch:
            candidates = self.running_batch.reqs
        else:
            candidates = self.running_batch.reqs + self.cur_batch.reqs

        for req in candidates:
            if req.finished() or not matches(req):
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
2254

2255
2256
2257
    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()

Chayenne's avatar
Chayenne committed
2258
2259
2260
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates to the TP worker. On success the cache is flushed and the
        flush is asserted to succeed; on failure the worker's message is logged.
        The reply's third positional field is always 0.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2267

2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
    def load_lora_adapter(
        self, recv_req: LoadLoRAAdapterReqInput
    ) -> LoadLoRAAdapterReqOutput:
        """In-place loading a new lora adapter from disk or huggingface.

        Delegates to the TP worker and returns its result object unchanged.
        On success the cache is flushed and the flush is asserted; on failure
        the error message is logged.
        """
        result = self.tp_worker.load_lora_adapter(recv_req)
        if not result.success:
            logger.error(result.error_message)
            return result

        flushed = self.flush_cache()
        assert flushed, "Cache flush failed after loading lora adapter."
        return result

    def unload_lora_adapter(
        self, recv_req: UnloadLoRAAdapterReqInput
    ) -> UnloadLoRAAdapterReqOutput:
        """Unload the lora adapter.

        Delegates to the TP worker and returns its result object unchanged.
        On success the cache is flushed and the flush is asserted; on failure
        the error message is logged.
        """
        result = self.tp_worker.unload_lora_adapter(recv_req)
        if not result.success:
            logger.error(result.error_message)
            return result

        flushed = self.flush_cache()
        assert flushed, "Cache flush failed after unloading LoRA weights"
        return result

2298
2299
2300
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Forwards the request to the TP worker and wraps the returned
        ``(success, message)`` pair in the reply type.
        """
        outcome = self.tp_worker.init_weights_update_group(recv_req)
        success, message = outcome
        return InitWeightsUpdateGroupReqOutput(success, message)
2302
2303

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates to the TP worker. On success the cache is flushed when
        ``recv_req.flush_cache`` is set, and that flush is asserted to succeed.
        On failure the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying ``(success, message)``.
            The previous ``Tuple[bool, str]`` annotation did not match what is
            actually returned.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2316

2317
2318
2319
2320
2321
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates to the TP worker. On success the cache is flushed when
        ``recv_req.flush_cache`` is set, and that flush is asserted. On failure
        the message is logged. Before replying, every rank waits on a barrier
        over the TP CPU group, whatever the outcome.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)
2329

2330
2331
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the requested parameter from the TP worker and wrap it in a reply."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2333

2334
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release the memory regions named by ``recv_req.tags``.

        A missing (``None``) or empty ``tags`` selects both the weights and the
        KV cache. This matches the default used by ``resume_memory_occupation``,
        so a release/resume pair sent with identical ``tags`` acts on the same
        regions.

        KV cache:
            The region is paused via ``memory_saver_adapter``, then the cache
            is flushed.

        Weights:
            1. The model's buffers are stashed with ``_export_static_state`` so
               that resume can restore them.
            2. The TP CPU group is synchronized.
            3. The weights region is paused via ``memory_saver_adapter``.
        """
        tags = recv_req.tags
        if tags is None or len(tags) == 0:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.stashed_model_static_state = _export_static_state(
                self.tp_worker.worker.model_runner.model
            )
            torch.distributed.barrier(self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()
2353

2354
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Re-acquire the memory regions named by ``recv_req.tags``.

        ``tags`` defaults to both regions when it is ``None`` or empty.

        Weights:
            1. The region is resumed.
            2. The TP CPU group is synchronized.
            3. The buffers stashed at release time are copied back into the
               model, and the stash is dropped.

        KV cache:
            The region is simply resumed.
        """
        requested = recv_req.tags
        if requested is None or len(requested) == 0:
            requested = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
        restore_weights = GPU_MEMORY_TYPE_WEIGHTS in requested
        restore_kv_cache = GPU_MEMORY_TYPE_KV_CACHE in requested

        if restore_weights:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(self.tp_cpu_group)
            _import_static_state(
                self.tp_worker.worker.model_runner.model,
                self.stashed_model_static_state,
            )
            del self.stashed_model_static_state

        if restore_kv_cache:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()

2373
2374
2375
2376
2377
2378
2379
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` as ``self.forward_sleep_time``.

        Non-positive values are normalized to ``None``. Presumably ``None``
        means "no artificial sleep"; the consumer is not visible here.
        """
        requested = recv_req.forward_sleep_time
        disabled = requested is not None and requested <= 0
        self.forward_sleep_time = None if disabled else requested
        return SlowDownReqOutput()

2380
    def profile(self, recv_req: ProfileReq):
        """Handle a start/stop profiling request.

        ``START_PROFILE`` first configures profiling via ``init_profile``:

        - With ``profile_by_stage``, the profiler is started later, per forward
          stage, by ``_profile_batch_predicate``.
        - Otherwise it is started immediately.
        - If ``init_profile`` fails (e.g. profiling is already in progress),
          its failure output is returned and no profiler is started.

        Any other request type stops profiling.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_output = self.init_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
            recv_req.profile_id,
        )
        if recv_req.profile_by_stage or not init_output.success:
            return init_output
        # start_profile's optional argument is a ForwardMode stage used only for
        # logging. Starting immediately has no stage, so pass nothing.
        return self.start_profile()

2406
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Record profiling settings without starting any profiler.

        Fails if profiling is already in progress.

        Defaults:
            - ``output_dir``: ``$SGLANG_TORCH_PROFILER_DIR``, falling back to ``/tmp``.
            - ``activities``: ``["CPU", "GPU"]``.

        Step budget:
            - A truthy ``num_steps`` with ``profile_by_stage``: a per-stage
              budget, with the prefill and decode counters reset to 0.
            - A truthy ``num_steps`` without it: an absolute ``forward_ct``
              target.
            - No ``num_steps``: the forward-count target is cleared.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage
        self.torch_profiler_output_dir = (
            os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
            if output_dir is None
            else output_dir
        )
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = ["CPU", "GPU"] if activities is None else activities
        self.profile_id = profile_id

        if not num_steps:
            self.profiler_target_forward_ct = None
        elif profile_by_stage:
            self.profile_steps = num_steps
            self.profiler_target_prefill_ct = num_steps
            self.profiler_target_decode_ct = num_steps
            self.profiler_prefill_ct = 0
            self.profiler_decode_ct = 0
        else:
            self.profile_steps = num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            self.profiler_target_forward_ct = self.forward_ct + num_steps

        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the profilers selected by ``self.profiler_activities``.

        Settings come from the most recent ``init_profile`` call. Supported
        activities:

        - ``"RPD"``: the rocpd tracer. Rank 0 first (re)creates ``trace.rpd``
          in the current working directory. Takes precedence over CPU/GPU.
        - ``"CPU"`` / ``"GPU"``: a ``torch.profiler.profile`` session. Used only
          when ``"RPD"`` is not requested.
        - ``"MEM"``: CUDA memory-history recording.
        - ``"CUDA_PROFILER"``: ``cudaProfilerStart``.

        Every profiler started here sets ``profile_in_progress`` so that
        ``stop_profile`` will tear it down.

        Args:
            stage: Optional forward stage; only used in the log message.
        """
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            # Without this flag, stop_profile() rejects the request and
            # cudaProfilerStop() is never reached.
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2516

2517
2518
2519
2520
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop all active profilers and write their output to ``torch_profiler_output_dir``.

        - torch profiler: exported as
          ``<profile_id>-TP-<rank>[-<stage>].trace.json.gz``, then the TP CPU
          group is synchronized.
        - RPD: the trace range is closed and flushed; after a TP barrier, rank 0
          converts ``trace.rpd`` into ``rpd_profile_path``.
        - MEM: a CUDA memory snapshot is dumped to a ``.pickle`` file and memory
          history recording is disabled.
        - CUDA_PROFILER: ``cudaProfilerStop`` is called.

        Args:
            stage: Optional forward stage, appended as a suffix to file names and logs.

        Returns:
            ProfileReqOutput: a failure if no profiling is in progress, success otherwise.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")

        # profiler_activities may be None; treat that as "no activities" for every
        # membership check below (the CUDA_PROFILER check previously crashed on it).
        activities = self.profiler_activities or []

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # Clear the attribute that start_profile() actually sets.
            self.rpd_profile_path = None

        if "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
2599
2600
            elif batch.forward_mode.is_idle():
                pass
2601
            else:
2602
                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
2603
2604
2605
2606
2607
2608
2609
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2610

2611
2612
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump the global expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is not one of the three known commands.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            action = "start_record"
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            action = "stop_record"
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            action = "dump_record"
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        getattr(get_global_expert_distribution_recorder(), action)()
        return ExpertDistributionReqOutput()
2621

2622
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a Session for ``recv_req.session_id``.

        The request is refused (with a warning) when the id already exists or
        is ``None``. The reply carries the id and whether a session was opened.
        """
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2636
2637
2638
2639
2640
2641
2642
2643
2644

    def close_session(self, recv_req: CloseSessionReqInput):
        """Delete the session with ``recv_req.session_id``, warning if it is unknown."""
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2645
2646
    def get_print_prefix(self):
        """Build a log prefix such as `` DP0 TP1 PP2``.

        Each part is included only when relevant: DP when ``attn_dp_rank`` is
        set, TP when ``tp_size > 1``, and PP when ``pp_size > 1``.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2655
2656
2657
2658
2659
2660
2661
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2662

2663
2664
2665
2666
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    A request counts as a health check when it has a string ``rid`` attribute
    that starts with ``"HEALTH_CHECK"``. Objects without a ``rid`` attribute,
    or whose ``rid`` is not a string (e.g. None), are not health checks.
    """
    rid = getattr(recv_req, "rid", None)
    # ``rid`` may exist but be None or a non-str value. Calling ``.startswith``
    # on it would raise AttributeError, so require a str first.
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2667
2668
2669
2670
2671
2672
2673
2674
2675
2676
2677
2678
2679
2680
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2681
2682
2683
2684
2685
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process: build a Scheduler and run its event loop.

    Args:
        server_args: Server configuration. ``tp_size``, ``pp_size`` and
            ``nnodes`` are read here, and the whole object is passed to
            ``Scheduler`` and ``configure_logger``.
        port_args: Passed through to ``Scheduler``.
        gpu_id: GPU index for this process. It is also used for CPU affinity
            when ``SGLANG_SET_CPU_AFFINITY`` is enabled.
        tp_rank: Tensor-parallel rank of this process.
        pp_rank: Pipeline-parallel rank of this process.
        dp_rank: Data-parallel rank, or None. When None and the
            ``SGLANG_DP_RANK`` environment variable is set, that value is used.
        pipe_writer: Object with a ``send`` method. Once the Scheduler is
            constructed, it receives a dict with ``status``,
            ``max_total_num_tokens`` and ``max_req_input_len``.

    The event loop is selected from the scheduler's disaggregation mode,
    ``pp_size`` and ``enable_overlap``. If any exception escapes the
    constructor or the event loop, the traceback is logged and SIGQUIT is
    sent to the parent process.
    """
    # [For Router] If env var "SGLANG_DP_RANK" exists, use it as dp_rank.
    # Resolve this BEFORE building the prefix. Otherwise the process title and
    # the log prefix would omit the DP rank when it comes from the environment.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix (rank tags used for the process title and the logger).
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB (default 100, overridable via env var),
    # converted to bytes for init_embedding_cache.
    embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", 100))
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        tb = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {tb}")
        parent_process.send_signal(signal.SIGQUIT)