scheduler.py 98.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
39
40
41
42
43
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
44
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
45
46
47
48
49
50
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
51
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
52
    ReqToMetadataIdxAllocator,
53
    TransferBackend,
54
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
55
)
56
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
57
58
59
60
61
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
62
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
63
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
64
65
66
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
67
68
from sglang.srt.managers.io_struct import (
    AbortReq,
69
    CloseSessionReqInput,
70
    ExpertDistributionReq,
71
    ExpertDistributionReqOutput,
72
73
    FlushCacheReqInput,
    FlushCacheReqOutput,
74
75
    GetInternalStateReq,
    GetInternalStateReqOutput,
76
77
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
78
    HealthCheckOutput,
79
80
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
81
82
    OpenSessionReqInput,
    OpenSessionReqOutput,
83
    ProfileReq,
84
85
    ProfileReqOutput,
    ProfileReqType,
86
87
88
89
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
90
91
    RpcReqInput,
    RpcReqOutput,
92
93
    SetInternalStateReq,
    SetInternalStateReqOutput,
94
95
    SlowDownReqInput,
    SlowDownReqOutput,
96
97
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
98
99
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
100
101
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
102
103
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
104
)
105
from sglang.srt.managers.mm_utils import init_embedding_cache
106
107
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
108
    MultimodalInputs,
109
110
    Req,
    ScheduleBatch,
111
    global_server_args_dict,
112
)
113
114
115
116
117
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
118
119
120
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
121
from sglang.srt.managers.session_controller import Session
122
from sglang.srt.managers.tp_worker import TpModelWorker
123
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
124
from sglang.srt.managers.utils import validate_input_length
125
from sglang.srt.mem_cache.chunk_cache import ChunkCache
126
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
127
from sglang.srt.mem_cache.radix_cache import RadixCache
128
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
129
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
130
from sglang.srt.reasoning_parser import ReasoningParser
131
from sglang.srt.server_args import PortArgs, ServerArgs
132
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
133
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
134
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
135
from sglang.srt.utils import (
136
    DeepEPMode,
137
    DynamicGradMode,
138
139
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
140
    disable_request_logging,
141
    get_available_gpu_memory,
142
    get_bool_env_var,
143
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
144
    kill_itself_when_parent_died,
145
    point_to_point_pyobj,
146
    pyspy_dump_schedulers,
147
    set_gpu_proc_affinity,
148
149
150
    set_random_seed,
    suppress_other_loggers,
)
151
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
152
153
154

logger = logging.getLogger(__name__)

# Debug-only switch for testing decode retraction (env: SGLANG_TEST_RETRACT).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Env flag SGLANG_RECORD_STEP_TIME; presumably enables step-time recording.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar timeout read from SGLANG_GRAMMAR_TIMEOUT (default 300, parsed as float;
# units presumably seconds — confirm at the use site).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))
159

160

161
162
@dataclass
class GenerationBatchResult:
    """Result of running one generation batch through the model worker.

    Fields are positional in the order declared below (dataclass ``__init__``).
    """

    # Output of the logits processor; may be None.
    logits_output: Optional[LogitsProcessorOutput]
    # Proxy tensors for hidden states (presumably passed between pipeline
    # stages, per the ``pp_`` prefix); may be None.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids; may be None.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start offsets within the extend region.
    extend_logprob_start_len_per_req: List[int]
    # Batch identifier.
    bid: int
    # Whether the batch could be run with a CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    embeddings: torch.Tensor
    bid: int


Byron Hsu's avatar
Byron Hsu committed
178
179
180
181
182
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
183
184
185
186
187
188
189
190
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Initialize the scheduler, its model worker(s), caches and IPC.

        Args:
            server_args: Server configuration; most scheduler knobs are read
                from it.
            port_args: IPC endpoint names and the NCCL port for the workers.
            gpu_id: GPU index, forwarded to the model worker(s).
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank (may be None), forwarded to the
                model worker(s).
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only the first attention-TP rank of
        # the first pipeline stage owns real sockets; every other rank gets no
        # receivers and senders whose send_pyobj is a no-op.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled.
        # think_end_id is the first token id of the parser's think-end token.
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker (client wrapper when overlapping)
        TpWorkerClass = TpModelWorkerClient if self.enable_overlap else TpModelWorker
        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        # Default micro-batch size: split the running-request budget across the
        # pipeline stages, keeping at least one request per micro batch.
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (e.g. -1) disables it; a
        # None size is also treated as disabled instead of failing on `<=`.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # The new-token ratio starts at init_new_token_ratio and decays linearly
        # towards min_new_token_ratio over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. parent_process is resolved before the daemon
        # thread starts so the attribute already exists on the instance when
        # the background thread begins running.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Init prefill/decode disaggregation state
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

477
478
    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Always sets ``self.model_config`` and ``self.is_generation``.

        - With ``skip_tokenizer_init``, ``self.tokenizer`` and
          ``self.processor`` are both None.
        - For multimodal models, a processor is loaded and the tokenizer is
          taken from it.
        - Otherwise only ``self.tokenizer`` is set here.
        """
        args = self.server_args

        self.model_config = ModelConfig.from_server_args(args)
        self.is_generation = self.model_config.is_generation

        if args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        common_kwargs = dict(
            tokenizer_mode=args.tokenizer_mode,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
        )

        if not self.model_config.is_multimodal:
            self.tokenizer = get_tokenizer(args.tokenizer_path, **common_kwargs)
            return

        # Multimodal: load the processor, then derive the tokenizer from it.
        self.processor = get_processor(
            args.tokenizer_path,
            **common_kwargs,
            use_fast=not args.disable_fast_image_processor,
        )
        self.tokenizer = get_tokenizer_from_processor(self.processor)

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Cache selection:

        - ``ChunkCache`` when chunked_prefill_size is set and the radix cache
          is disabled.
        - Otherwise ``HiRadixCache`` when hierarchical caching is enabled.
        - Otherwise ``RadixCache``.

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise draft tokens + eagle_topk * num_steps.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        pools = dict(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            self.tree_cache = ChunkCache(**pools, page_size=self.page_size)
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                **pools,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                **pools,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            # Presumably reserves room for the draft tokens a decode step may
            # write under speculative decoding; confirm against the consumer.
            steps_tokens = args.speculative_eagle_topk * args.speculative_num_steps
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens + steps_tokens
            )
550
551
552

    def init_metrics(self):
        """Reset throughput and speculative-decoding counters.

        When ``enable_metrics`` is set, also create the metrics collector.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # Batch size -> list of recorded step times.
        self.step_time_dict = defaultdict(list)
        # Speculative-decoding acceptance counters.
        self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        self.metrics_collector = SchedulerMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                "engine_type": "unified",
            },
        )
Lianmin Zheng's avatar
Lianmin Zheng committed
568

569
570
571
572
    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher if KV cache events are enabled.

        Args:
            kv_events_config: Passed to ``EventPublisherFactory.create``.
                Ignored when ``self.enable_kv_cache_events`` is false, in
                which case ``kv_event_publisher`` is not set.
        """
        if not self.enable_kv_cache_events:
            return
        self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)

Byron Hsu's avatar
Byron Hsu committed
573
    def init_disaggregation(self):
        """Set up queues and metadata buffers for prefill/decode disaggregation.

        Always sets ``self.transfer_backend``.

        - DECODE mode: builds the KV transfer queue and the pre-allocation
          queue. Metadata buffers get twice the req_to_token_pool size.
        - PREFILL mode: builds the bootstrap queue and an empty in-flight
          list. Metadata buffers get twice max_running_requests.
        - Any other mode: only the transfer backend is set.
        """
        args = self.server_args
        self.transfer_backend = TransferBackend(args.disaggregation_transfer_backend)
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            buffer_size = (self.req_to_token_pool.size) * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            draft_kv_pool = (
                self.draft_worker.model_runner.token_to_kv_pool
                if self.draft_worker is not None
                else None
            )
            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_kv_pool,
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0
        elif mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            target_kv_pool = self.token_to_kv_pool_allocator.get_kvcache()
            draft_kv_pool = (
                self.draft_worker.model_runner.token_to_kv_pool
                if self.draft_worker is not None
                else None
            )
            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=target_kv_pool,
                draft_token_to_kv_pool=draft_kv_pool,
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
646

647
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration does the following:

        1. Receive and dispatch incoming requests.
        2. Pick the next batch.
        3. Either run the batch and process its result synchronously, or, when
           there is nothing to run, do idle-time self-checks.

        This loop never returns.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                batch_result = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_result)

            self.last_batch = next_batch
666

667
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Results are processed one iteration late. The batch launched in one
        iteration is queued on ``self.result_queue``. Its result is processed
        in the next iteration, after the following batch has been launched.

        This loop never returns.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Event passed to result processing; presumably signalled by
                # the worker once this batch has been launched.
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                # Queue a snapshot of the batch; its result is consumed later.
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    warmup_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(warmup_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                    # Use the current launched batch's launch_done event,
                    # not the last batch's.
                    launch_done = batch.launch_done
                else:
                    prev_batch.next_batch_sampling_info = None
                    launch_done = None
                self.process_batch_result(prev_batch, prev_result, launch_done)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop keeps ``pp_size`` micro-batches in flight. For micro-batch
        ``mb_id`` it schedules and runs a batch, then post-processes micro-batch
        ``(mb_id + 1) % pp_size`` using the token ids obtained through
        ``pp_group.recv_tensor_dict``. Non-last ranks also forward the received
        requests, the carried outputs and the hidden-state proxy tensors to the
        next stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Per-micro-batch CUDA-graph flag, recorded next to ``bids``. The result
        # rebuilt for ``mbs[next_mb_id]`` must report the flag of *that*
        # micro-batch's forward pass, not the flag of whatever ``run_batch``
        # call happened most recently (which may belong to another micro-batch
        # or be stale when the current micro-batch ran nothing).
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    bids[mb_id] = result.bid
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": result.next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

814
815
    def recv_requests(self) -> List[Req]:
        """Collect incoming requests on the entry rank and share them with peers.

        On pipeline stage 0, attention-TP rank 0 drains the tokenizer socket and
        then the RPC socket without blocking. On later stages, attention-TP rank 0
        receives the request list forwarded by the previous stage. Other
        attention-TP ranks start with ``None`` and obtain the list by broadcast.
        With DP attention, work requests (tokenized generate/embedding) are
        broadcast within the attention-TP group, while control requests are
        broadcast across the whole TP group.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        elif self.pp_rank == 0:
            recv_reqs = []
            for socket in (self.recv_from_tokenizer, self.recv_from_rpc):
                # Drain everything currently queued on this socket.
                while True:
                    try:
                        recv_reqs.append(socket.recv_pyobj(zmq.NOBLOCK))
                    except zmq.ZMQError:
                        break
        else:
            dp_offset = self.attn_dp_rank * self.attn_tp_size
            recv_reqs = point_to_point_pyobj(
                [],
                self.pp_rank * self.tp_size + dp_offset,
                self.world_group.cpu_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                self.pp_rank * self.tp_size + dp_offset,
            )

        if self.server_args.enable_dp_attention:
            work_reqs = None
            control_reqs = None
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    is_work = isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                    (work_reqs if is_work else control_reqs).append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
892
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any reply it produces.

        A health-check generate request is skipped, and counted in
        ``return_health_check_ct``, while a chunked request is pending or the
        running batch is non-empty. A non-None reply of type ``RpcReqOutput`` is
        sent back on the RPC socket (if one exists). Every other non-None reply
        goes to the tokenizer.
        """
        for recv_req in recv_reqs:
            if is_health_check_generate_req(recv_req):
                scheduler_busy = (
                    self.chunked_req is not None or not self.running_batch.is_empty()
                )
                if scheduler_busy:
                    self.return_health_check_ct += 1
                    continue

            output = self._request_dispatcher(recv_req)
            if output is None:
                continue

            if not isinstance(output, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(output)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(output)
908
909
910
911
912

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        Requests without a known session get a fresh ``Req``. Requests naming
        an existing session are created through that session. Invalid requests
        are not dropped:
          - in the normal path they are marked with ``FINISH_ABORT`` and enqueued;
          - in disaggregated mode they are aborted via ``prepare_abort`` and
            ``stream_output``.
        Requests with a grammar constraint whose grammar is not yet cached wait
        in ``grammar_queue``. All others go through ``_add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_message = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_message)
                    prepare_abort(req, error_message)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Clamp max_new_tokens so prompt + output stays within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request. Every branch tests `is not None`
        # so that the chosen key is always bound whenever a constraint is set
        # (previously a non-None but falsy structural_tag fell through all
        # branches and left the key unbound).
        sampling_params = req.sampling_params
        grammar_key = None
        if sampling_params.json_schema is not None:
            grammar_key = ("json", sampling_params.json_schema)
        elif sampling_params.regex is not None:
            grammar_key = ("regex", sampling_params.regex)
        elif sampling_params.ebnf is not None:
            grammar_key = ("ebnf", sampling_params.ebnf)
        elif sampling_params.structural_tag is not None:
            grammar_key = ("structural_tag", sampling_params.structural_tag)

        add_to_grammar_queue = False
        if grammar_key is not None:
            assert self.grammar_backend is not None
            value, cache_hit = self.grammar_backend.get_cached_or_future_value(
                grammar_key
            )
            req.grammar = value

            if not cache_hit:
                req.grammar_key = grammar_key
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req`` with its enqueue time and place it in the entry queue.

        The queue depends on ``self.disaggregation_mode``:
          - PREFILL servers use the bootstrap queue;
          - DECODE servers use the pre-allocation queue;
          - all other servers use the plain waiting queue.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

1080
1081
1082
1083
1084
    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append several requests to the entry queue for the current mode.

        Bulk counterpart of ``_add_request_to_queue``. Note that unlike that
        helper, this one does not set ``queue_time_start``.
        """
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            target_queue = self.disagg_prefill_bootstrap_queue
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # If this is a decode server, we put the request to the decode pending prealloc queue
            target_queue = self.disagg_decode_prealloc_queue
        else:
            target_queue = self.waiting_queue
        target_queue.extend(reqs)
1088
1089
1090

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Image inputs, when present, are expanded into padded placeholder token
        ids. Requests that fail a length check are still enqueued through
        ``_add_request_to_queue`` rather than dropped. The multimodal check
        marks them explicitly with ``FINISH_ABORT``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Route through the shared helper (which also stamps
                # queue_time_start) so this abort lands in the same queue as the
                # length-validation abort below, including in disaggregated modes,
                # instead of always going to waiting_queue.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1139

1140
1141
1142
1143
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a summary line for a newly formed prefill batch and update metrics.

        Args:
            adder: The ``PrefillAdder`` that built the batch. Its
                ``log_input_tokens`` / ``log_hit_tokens`` counters feed the log
                line and the cache-hit rate.
            can_run_list: Requests admitted into the batch. Their
                ``queue_time_start`` / ``queue_time_end`` give the queue latency.
            running_bs: Number of requests in the running batch.
        """
        # Take a single timestamp so consecutive intervals abut exactly.
        now = time.perf_counter()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
        else:
            f += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(f)

        if self.enable_metrics:
            total_prefill_tokens = adder.log_input_tokens + adder.log_hit_tokens
            # Avoid 0/0 when the batch reports no tokens at all.
            cache_hit_rate = (
                adder.log_hit_tokens / total_prefill_tokens
                if total_prefill_tokens > 0
                else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            # An empty can_run_list would otherwise raise ZeroDivisionError.
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1192

1193
1194
1195
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a summary line for the decode phase and update metrics.

        Args:
            can_run_cuda_graph: Whether the last decode step ran via CUDA graph;
                only reported in the log line.
            running_batch: Batch to report on. Falls back to
                ``self.running_batch`` when not given (or falsy).
        """
        batch = running_batch or self.running_batch

        # Take a single timestamp so consecutive intervals abut exactly.
        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard against an interval with no speculative forward passes.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0.0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1252

Lianmin Zheng's avatar
Lianmin Zheng committed
1253
1254
    def check_memory(self):
        """Verify memory-pool accounting and periodically emit idle metrics.

        Invoked from the event loops when no batch is scheduled.

        Raises:
            ValueError: if the KV allocator's available size plus the evictable
                cache size differs from ``max_total_num_tokens`` (minus the
                protected size when hierarchical caching is enabled), or if
                ``req_to_token_pool`` does not have all of its slots free.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_available = (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if available_size != expected_available:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the headline from the details.
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            # available_size was computed above and nothing has changed since.
            num_used = self.max_total_num_tokens - available_size
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1300

1301
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        The previous extend (prefill) batch is folded into the running batch,
        minus any chunked requests. A new prefill batch is then preferred. If
        none is available, the running batch is updated for decode, and
        ``None`` is returned when nothing remains to run. With DP attention or
        SP layernorm, the choice is passed through ``prepare_dp_attn_batch``.
        """
        excluded_chunked_reqs = set()
        pending_chunk = self.chunked_req
        if pending_chunk:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            excluded_chunked_reqs.add(pending_chunk)
            self.tree_cache.cache_unfinished_req(pending_chunk)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(pending_chunk.req_pool_idx)

        prev_batch = self.last_batch
        if prev_batch and prev_batch.forward_mode.is_extend():
            if prev_batch.chunked_req is not None:
                # In the context pipeline parallelism, after the last chunk, the
                # current microbatch still track outdated chunked_req.
                # We need to discard it.
                excluded_chunked_reqs.add(prev_batch.chunked_req)

            size_before_filter = prev_batch.batch_size()
            prev_batch.filter_batch(chunked_req_to_exclude=list(excluded_chunked_reqs))
            if prev_batch.batch_size() < size_before_filter:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not prev_batch.is_empty():
                if not self.running_batch.is_empty():
                    self.running_batch.merge_batch(prev_batch)
                else:
                    self.running_batch = prev_batch

        prefill_batch = self.get_new_batch_prefill()
        if prefill_batch is not None:
            # Run prefill first if possible
            chosen = prefill_batch
        elif self.running_batch.is_empty():
            chosen = None
        else:
            # Run decode
            self.running_batch = self.update_running_batch(self.running_batch)
            chosen = None if self.running_batch.is_empty() else self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            chosen, _ = self.prepare_dp_attn_batch(chosen)

        return chosen
1350

1351
1352
1353
1354
1355
1356
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may join a batch of ``running_bs`` requests.

        The budget is ``max_micro_batch_size - running_bs``. With pipeline
        parallelism (``pp_size > 1``) it is additionally capped by the free
        slots in ``req_to_token_pool``. The result can be zero or negative.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1357
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue, or return None.

        Requests whose grammars became ready are moved into the waiting queue
        first. A pending chunked request is continued before anything else.
        Waiting requests are then offered to a ``PrefillAdder`` in queue order
        (after ``self.policy.calc_priority`` has presumably reordered the
        queue) until a limit is hit: LoRA adapter count, allocatable request
        slots, request-pool space in disaggregated prefill mode, or the adder
        itself refusing more. Hitting most limits marks ``running_batch`` as
        full. With mixed chunked prefill the running decode requests are merged
        into the returned batch and ``running_batch`` is reset to empty.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or not self.waiting_queue
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs
        # should always be greater than 0, as the space for the chunked request has
        # just been released.
        # In the PP case, a chunked req can start in one microbatch and end in another
        # microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow the chunked request to be added; otherwise
        # there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # Check for completion of hierarchical cache activities to release memory.
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode the prealloc and transfer queues also hold
                # req_to_token slots, so check the pool's actual free size too.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full only once some request can actually be served.
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None

        if self.enable_metrics:
            # Only record queue time when enable_metrics is True to avoid overhead.
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the lookup set once; rebuilding it for every waiting request
        # made this filter O(len(waiting_queue) * len(can_run_list)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1511
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step and return it.

        Finished requests are filtered out first. If ``check_decode_mem`` says
        the KV cache cannot hold another decode step (or, when ``TEST_RETRACT``
        is set, for any batch larger than 10 requests), some requests are
        retracted back to the waiting queue and ``new_token_ratio`` takes the
        value returned by ``retract_decode``. Otherwise ``new_token_ratio``
        decays by ``new_token_ratio_decay``, floored at ``min_new_token_ratio``.
        The same ``batch`` object is returned.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        needs_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if needs_retract:
            previous_ratio = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted)
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        if batch.batch_size() < size_before:
            # Requests left the batch, so there is room to admit new ones.
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1547

1548
1549
1550
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Execute one forward pass for ``batch`` and wrap its outputs.

        Increments ``forward_ct``, gives the profiler a chance to start/stop,
        and optionally sleeps ``forward_sleep_time`` seconds first. Embedding /
        reward models return an ``EmbeddingBatchResult``. Generation models run
        either the TP worker or, under speculative decoding, the draft worker,
        and return a ``GenerationBatchResult``. On a non-last pipeline rank
        that result carries the PP proxy tensors instead of logits/token ids.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            return EmbeddingBatchResult(
                embeddings=self.tp_worker.forward_batch_embedding(worker_batch),
                bid=worker_batch.bid,
            )

        is_last_rank = self.pp_group.is_last_rank
        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            outputs = self.tp_worker.forward_batch_generation(worker_batch)
            if is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = outputs
            else:
                pp_proxy_tensors, _, can_run_cuda_graph = outputs
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        if is_last_rank:
            batch.output_ids = next_token_ids

        # Output processing needs these per-request values, but the overlap
        # scheduler may modify them afterwards, so snapshot them now.
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]
        else:
            extend_input_len_per_req = extend_logprob_start_len_per_req = None

        return GenerationBatchResult(
            logits_output=logits_output if is_last_rank else None,
            pp_hidden_states_proxy_tensors=None if is_last_rank else pp_proxy_tensors,
            next_token_ids=next_token_ids if is_last_rank else None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1622

1623
1624
1625
1626
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch ``result`` to the handler for ``batch``'s forward mode.

        Decode and extend results go to their dedicated handlers. Idle batches
        only resolve the overlapped worker's last result (when overlap is on),
        and dummy-first batches just release the next batch's sampling info.
        Afterwards, if health checks are pending, one ``HealthCheckOutput`` is
        sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Answer health checks from here so they are not starved by long
        # context prefill. Note this path does not check the detokenizer
        # manager's status.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1647
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Sync ``local_batch`` metadata across DP attention ranks.

        Thin wrapper over ``prepare_dp_attn_batch_raw`` that fills in this
        scheduler's process group, idle-batch factory and server arguments.
        Returns that method's result unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Exchange per-rank batch metadata across DP attention workers.

        Each rank in ``tp_cpu_group`` contributes an int64 row
        ``[num_tokens, can_cuda_graph, num_tokens_for_logprob,
        is_extend_in_batch, *two-batch-overlap fields]`` and receives every
        rank's row via ``all_gather_into_tensor``. If this rank has no batch
        but some DP rank has tokens, an idle batch is created with
        ``get_idle_batch``.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` is the (possibly
            newly created idle) batch annotated with the gathered token counts,
            the TBO split and the CUDA-graph decision, or ``None``.
            ``any_extend`` says whether any rank reported an extend batch.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # Under EAGLE, count speculative_num_draft_tokens per request.
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                # We should have at least 1 token for sample in every case.
                max(extend_len - logprob_start_len, 1)
                for logprob_start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # The first four fields of each row are fixed; the two-batch-overlap
        # preparer appends its own fields after them.
        num_base_fields = 4
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        # Size the receive buffer from the row actually sent instead of a
        # hard-coded width, so the two cannot drift apart.
        global_info = torch.empty(
            (dp_size, attn_tp_size, local_info.numel()),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Per DP rank, read the row of attention-TP rank 0.
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, num_base_fields:]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1769
1770
1771
1772
1773

    def get_idle_batch(self):
        """Create an empty ``ScheduleBatch`` prepared for an idle forward pass.

        ``prepare_dp_attn_batch`` passes this as the factory used when this
        rank has no batch while some other DP rank has tokens.
        """
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1784
1785
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Requests are examined in queue order. Each pending grammar future is
        polled for up to 0.03 s, and the scan stops at the first one that is
        not ready. Once a request's future has been polled unsuccessfully more
        than ``GRAMMAR_TIMEOUT / 0.03`` times, the request is aborted with
        ``FINISH_ABORT`` (HTTP 400) and moved along with the ready ones.

        With more than one TP rank, the ranks agree on one common prefix of
        the queue to move:
        - every rank blocks on grammars that some rank already found ready;
        - a request is aborted only if some rank actually timed out on it.
        """
        num_ready_reqs = 0
        timed_out = False
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                num_ready_reqs += 1
            except futures.TimeoutError:
                # The counter counts 0.03 s polls of this future, not wall time
                # spent in the grammar queue.
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    timed_out = True
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        # Queue index of the request this rank gave up on, or -1 if none.
        timeout_idx = num_ready_reqs if timed_out else -1
        num_ready_reqs_max, timeout_idx_max = num_ready_reqs, timeout_idx

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, timeout_idx], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, timeout_idx_max = tensor.tolist()

            # Block on grammars another rank already has ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())

        # The candidate for abortion is the request right after the
        # synchronized ready prefix. It is aborted only when a rank timed out
        # exactly there. This runs for a single rank too: previously a
        # timed-out request was never aborted when tp_size == 1, and with
        # tp_size > 1 the abort started at the local (not synchronized) index.
        num_moved = num_ready_reqs_max
        if timeout_idx_max >= 0 and timeout_idx_max == num_ready_reqs_max:
            req = self.grammar_queue[num_ready_reqs_max]
            req.grammar.cancel()
            req.grammar = None
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
            num_moved += 1

        self._extend_requests_to_queue(self.grammar_queue[:num_moved])
        self.grammar_queue = self.grammar_queue[num_moved:]

1836
1837
1838
1839
1840
1841
1842
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Signal that the next batch's sampling info is ready to use.

        When grammars are attached, the regex vocab mask is rebuilt and the
        current stream synchronized before ``sampling_info_done`` is set.
        Does nothing when ``batch.next_batch_sampling_info`` is falsy.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

1843
1844
1845
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        While ``cur_batch`` is set and ``forward_ct`` has not advanced for
        more than ``watchdog_timeout`` seconds, the loop exits. It then logs
        diagnostics, dumps the scheduler stacks with py-spy, and sends SIGQUIT
        to the parent process. This method never returns normally before that.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: with floor division a timeout below 2 s gave a
            # zero-length sleep and turned this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            # Other code (e.g. flush_cache) may reset cur_batch to None while
            # this thread reads it; a failure here must not skip the kill below.
            try:
                logger.error(
                    f"{self.cur_batch.batch_size()=}, "
                    f"{self.cur_batch.reqs=}, "
                    f"{self.token_to_kv_pool_allocator.available_size()=}, "
                    f"{self.tree_cache.evictable_size()=}, "
                )
            except Exception:
                logger.exception("Watchdog failed to log the scheduler state.")

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1877
1878
1879
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether the flush happened."""
        return FlushCacheReqOutput(success=self.flush_cache())
1880

1881
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when nothing is waiting or running, including
        every pipeline microbatch when ``pp_size > 1``. It does the following:
        - clears ``cur_batch``/``last_batch``;
        - resets the tree cache and grammar backend;
        - clears the request/token pools, plus the draft worker's pools under
          speculative decoding;
        - zeroes the generation and speculation counters;
        - empties the CUDA cache.

        Returns:
            True if the cache was flushed, False if it was skipped because
            requests are still pending.
        """
        idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not idle:
            # Use the module logger (as below) rather than the root logger.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
    def get_load(self):
        """Estimate this scheduler's load in tokens.

        Adds up two quantities:
        - KV-cache tokens in use, i.e. neither free nor evictable;
        - prompt lengths of waiting requests, plus those in the prefill
          bootstrap queue or the decode prealloc queue in disaggregated mode.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        in_use = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )

        pending_inputs = [req.origin_input_ids for req in self.waiting_queue]
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            pending_inputs.extend(
                req.origin_input_ids
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
        elif mode == DisaggregationMode.DECODE:
            # Prealloc entries wrap the request in a ``.req`` attribute.
            pending_inputs.extend(
                entry.req.origin_input_ids
                for entry in self.disagg_decode_prealloc_queue.queue
            )

        return in_use + sum(map(len, pending_inputs))

1939
1940
1941
1942
1943
1944
1945
1946
1947
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Snapshot the global server args plus live scheduler statistics.

        The snapshot always includes ``last_gen_throughput`` and ``load``.
        It also includes ``avg_spec_accept_length`` under speculative decoding
        (once any acceptance was recorded) and ``step_time_dict`` when
        ``RECORD_STEP_TIME`` is enabled.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
1952
1953
1954
1955
1956

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Validate and apply a runtime update to selected server args.

        Only the keys in ``args_allow_update`` may be changed.
        ``max_micro_batch_size`` must also lie in
        ``[1, max_running_requests // pp_size]``.

        The update is all-or-nothing. Validation stops at the first rejected
        key and then nothing is applied. On success:
        - the average speculative accept length is logged (if any was counted);
        - the cumulative accept counters are reset;
        - every new value is written into ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput. ``updated`` reports whether the update
            was actually applied. ``server_args`` is the (possibly unchanged)
            global server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            if k == "max_micro_batch_size":
                upper_bound = self.max_running_requests // self.pp_size
                if v > upper_bound or v < 1:
                    logger.warning(
                        f"Updating {k} to {v} is rejected because it is out of the valid range [1, {upper_bound}]."
                    )
                    if_success = False
                    break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            # Restart the acceptance statistics so they reflect the new settings.
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            global_server_args_dict.update(server_args_dict)
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        # Report the real outcome: previously `updated=True` was returned even
        # when the request was rejected and nothing changed.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        Any exception raised by the attribute lookup or by the call is caught
        and logged. Its string form is returned as the error message.
        ``barrier()`` (torch.distributed) is called on every path, whether the
        call succeeded or not.

        NOTE(review): ``recv_req.method`` is resolved with ``getattr`` on the
        scheduler, so any attribute name can be invoked. This assumes RPC
        requests come only from trusted callers.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        error = None  # renamed from `exec`, which shadowed the builtin
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Save this rank's model to ``params["url"]`` via its model runner.

        Args:
            params: Mapping that must contain ``"url"``. The value is passed
                unchanged to ``model_runner.save_remote_model``.
        """
        url = params["url"]
        worker = self.tp_worker.worker
        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Save this rank's model shards via its model runner.

        Args:
            params: Mapping with ``"path"``, ``"pattern"`` and ``"max_size"``
                entries. They are forwarded as keyword arguments to
                ``model_runner.save_sharded_model``.
        """
        worker = self.tp_worker.worker
        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

2026
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Waiting queue: matching requests are removed immediately, and an
        ``AbortReq`` carrying the request's rid is sent to the tokenizer for
        each one.

        Running batch (and ``cur_batch`` when it is a different batch):
        matching requests that have not finished are only flagged with
        ``to_abort = True``. The flag is presumably acted on later in the
        scheduler loop — not visible here.
        """
        # TODO(lmzheng): abort the requests in the grammar queue.

        # Delete requests in the waiting queue
        to_del = [
            i
            for i, req in enumerate(self.waiting_queue)
            if req.rid.startswith(recv_req.rid)
        ]
        # Pop from the highest index down so the remaining indices stay valid.
        for i in reversed(to_del):
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2051

2052
2053
2054
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Not implemented in this scheduler; always raises ``NotImplementedError``.

        Per the annotation, an implementation would return a list of ``Req``
        and an ``int``. Their meaning is not visible here.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2055
2056
2057
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        The load is delegated to the TP worker.
        - On success the cache is flushed, presumably so that no state computed
          with the old weights is reused. A failed flush trips an assertion.
        - On failure the worker's message is logged.
        Either way the ``(success, message)`` pair is returned.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        # NOTE(review): the trailing 0 is a third positional field of the
        # output; its meaning is not visible here.
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2064

2065
2066
2067
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result.
        """
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)
2069
2070

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> "UpdateWeightsFromDistributedReqOutput":
        """Update the online model parameter.

        The update is delegated to the TP worker.
        - On success the cache is flushed, and a failed flush trips an
          assertion.
        - On failure the worker's message is logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput(success, message)``. The
            previous annotation ``Tuple[bool, str]`` did not match what is
            returned.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2082

2083
2084
2085
2086
2087
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Unlike the disk and distributed variants, flushing the cache after a
        successful update is optional and controlled by
        ``recv_req.flush_cache``. When requested, a failed flush trips an
        assertion. On failure the worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)
2094

2095
2096
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the parameter described by ``recv_req`` from the TP worker."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)
2098

2099
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Pause the memory saver adapter after stashing the model's buffers.

        Steps, in order:
        1. Clone every model buffer into ``self.stashed_model_static_state``,
           so ``resume_memory_occupation`` can restore it.
        2. Pause the adapter.
        3. Flush the cache.

        NOTE(review): the result of ``flush_cache()`` is ignored here. A
        failed flush (pending requests) is only reported by ``flush_cache``'s
        own warning.
        """
        self.memory_saver_adapter.check_validity(
            caller_name="release_memory_occupation"
        )
        # Snapshot the buffers before pausing, while their storage is still valid.
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2109

2110
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        This is the counterpart of ``release_memory_occupation``. The buffer
        values stashed there are copied back into the model in place, and the
        stash is then deleted. Calling this without a prior release raises
        ``AttributeError`` on the missing stash.
        """
        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

2119
2120
2121
2122
2123
2124
2125
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` as ``self.forward_sleep_time``.

        A non-positive value is normalized to ``None``. How the stored value
        is consumed (presumably as a per-forward delay) is not visible here.
        """
        sleep_time = recv_req.forward_sleep_time
        is_non_positive = sleep_time is not None and sleep_time <= 0
        self.forward_sleep_time = None if is_non_positive else sleep_time
        return SlowDownReqOutput()

2126
    def profile(self, recv_req: ProfileReq):
        """Start or stop profiling according to ``recv_req.type``.

        For ``START_PROFILE`` the settings are recorded with ``init_profile``:
        - In by-stage mode profiling is only armed here.
          ``_profile_batch_predicate`` then starts and stops it per forward
          stage.
        - Otherwise profiling starts immediately.
        A failed ``init_profile`` result (e.g. a profile is already running)
        is returned as-is, without starting anything.

        Any other request type stops the active profile.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_result = self.init_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
        )
        if recv_req.profile_by_stage or not init_result.success:
            # Starting here after a failed init would launch a second profiler
            # on top of the one already running.
            return init_result
        # No stage label: this is a whole-run profile, not a per-stage one.
        return self.start_profile()

2150
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
    ) -> ProfileReqOutput:
        """Record profiling settings without starting any profiler.

        Args:
            output_dir: Trace directory. Defaults to
                ``$SGLANG_TORCH_PROFILER_DIR``, or ``/tmp`` if that is unset.
            num_steps: If truthy, profiling is bounded.
                - By-stage mode profiles this many prefill batches and this
                  many decode batches.
                - Otherwise profiling is stopped once ``forward_ct`` reaches
                  ``forward_ct + num_steps``.
                If falsy (None or 0), the forward-count target is cleared.
            activities: Any of ``CPU``, ``GPU``, ``MEM``, ``RPD``,
                ``CUDA_PROFILER``. Defaults to ``["CPU", "GPU"]``.
            with_stack: Forwarded to ``torch.profiler`` by ``start_profile``.
            record_shapes: Forwarded to ``torch.profiler`` by ``start_profile``.
            profile_by_stage: Profile prefill and decode stages separately.

        Returns:
            A failed ``ProfileReqOutput`` if a profile is already in progress,
            otherwise a successful one.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = activities

        if num_steps:
            self.profile_steps = num_steps
            if self.profile_by_stage:
                self.profiler_target_prefill_ct = num_steps
                self.profiler_target_decode_ct = num_steps
                self.profiler_prefill_ct = 0
                self.profiler_decode_ct = 0
            else:
                self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            # NOTE(review): in by-stage mode the per-stage counters/targets are
            # not (re)initialized on this path, yet _profile_batch_predicate
            # reads them — confirm they are initialized elsewhere.
            self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the profilers selected by ``init_profile``.

        Activities:
        - ``RPD`` (rocpd tracer) takes precedence over ``torch.profiler``. When
          it is requested, ``CPU``/``GPU`` are not traced with the torch
          profiler.
        - ``MEM`` additionally records CUDA allocator history.
        - ``CUDA_PROFILER`` calls ``cudaProfilerStart``.
        Every started activity marks the profile as in progress, so that
        ``stop_profile`` will tear it down.

        Args:
            stage: Forward stage being profiled (by-stage mode). It is only
                used in the log message.
        """
        stage_str = f" for {stage!s}" if stage is not None else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir}",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                # Rank 0 (re)creates the shared trace DB in the current working
                # directory before every rank starts tracing into it.
                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            # Without this, a CUDA_PROFILER-only session was never marked
            # active: stop_profile refused to run and cudaProfilerStop was
            # never issued.
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2258

2259
2260
2261
2262
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop all active profilers and write their traces.

        Output files, all under ``torch_profiler_output_dir``:
        - Torch traces: ``<time>-TP-<rank>[-<stage>].trace.json.gz``.
        - MEM snapshots: ``<time>-TP-<rank>-memory[-<stage>].pickle``.
        - The RPD trace: converted to Chrome format by TP rank 0, at the path
          chosen in ``start_profile``.

        Args:
            stage: Forward stage that was profiled (by-stage mode). Appended to
                the file names and the log line.

        Returns:
            A failed ``ProfileReqOutput`` if no profile is in progress,
            otherwise a successful one.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        stage_suffix = f"-{stage!s}" if stage is not None else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time())
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            # Every rank must have flushed before rank 0 converts the shared DB.
            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # Reset the attribute start_profile actually sets (it was
            # previously misspelled `rpd_profiler_path`, leaving this stale).
            self.rpd_profile_path = None

        activities = self.profiler_activities
        if activities is not None and "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if activities is not None and "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        """Start or stop the profiler around the batch about to run.

        Without by-stage profiling: stop the profiler once ``forward_ct`` has
        reached a truthy ``profiler_target_forward_ct``.

        With by-stage profiling:
        - The first prefill batch starts a trace.
        - The first decode batch first stops any trace still running (labelled
          EXTEND), then starts a decode trace.
        - Each stage's trace is stopped once that stage's batch counter exceeds
          its target.

        Raises:
            RuntimeError: in by-stage mode, the batch is neither prefill nor
                decode.
        """
        if not self.profile_by_stage:
            # Check profiler
            target_ct = self.profiler_target_forward_ct
            if target_ct and target_ct <= self.forward_ct:
                self.stop_profile()
            return

        mode = batch.forward_mode
        if mode.is_prefill():
            if self.profiler_prefill_ct == 0:
                self.start_profile(mode)
            self.profiler_prefill_ct += 1
            past_target = self.profiler_prefill_ct > self.profiler_target_prefill_ct
            if past_target and self.profile_in_progress:
                self.stop_profile(stage=ForwardMode.EXTEND)
        elif mode.is_decode():
            if self.profiler_decode_ct == 0:
                if self.profile_in_progress:
                    # Force the still-running trace to be flushed before the
                    # decode trace starts.
                    self.stop_profile(ForwardMode.EXTEND)
                self.start_profile(mode)
            self.profiler_decode_ct += 1
            past_target = self.profiler_decode_ct > self.profiler_target_decode_ct
            if past_target and self.profile_in_progress:
                self.stop_profile(stage=ForwardMode.DECODE)
        else:
            raise RuntimeError("unsupported profile stage")
2347

2348
2349
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the global expert-distribution recording.

        Raises:
            ValueError: ``recv_req`` is not ``START_RECORD``, ``STOP_RECORD``
                or ``DUMP_RECORD``.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            get_global_expert_distribution_recorder().start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            get_global_expert_distribution_recorder().stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
2358

2359
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new ``Session`` under ``recv_req.session_id``.

        Returns:
            ``OpenSessionReqOutput(session_id, False)``, with a warning, when
            the id is already in use or is None. Otherwise the session is
            created with ``recv_req.capacity_of_str_len`` and
            ``OpenSessionReqOutput(session_id, True)`` is returned.
        """
        session_id = recv_req.session_id
        # handle error
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2373
2374
2375
2376
2377
2378
2379
2380
2381

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session named by ``recv_req.session_id``.

        An unknown id is reported with a warning and otherwise ignored.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2382
2383
    def get_print_prefix(self):
        """Build the ``" DP<r> TP<r> PP<r>"`` log prefix for this scheduler.

        Each part is included only when relevant:
        - DP when ``attn_dp_rank`` is not None;
        - TP when ``server_args.tp_size > 1``;
        - PP when ``pp_size > 1``.
        Returns ``""`` when none apply.
        """
        prefix = ""
        if self.attn_dp_rank is not None:
            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
            prefix += f" PP{self.pp_rank}"
        return prefix

2392
2393
2394
2395
2396
2397
2398
    def _publish_kv_events(self):
        """Publish pending tree-cache KV events, when KV event publishing is enabled.

        Events are fetched with ``tree_cache.take_events()``. If any were
        returned, they are published as a single ``KVEventBatch`` stamped with
        ``time.time()``.
        """
        if not self.enable_kv_cache_events:
            return
        pending_events = self.tree_cache.take_events()
        if not pending_events:
            return
        event_batch = KVEventBatch(ts=time.time(), events=pending_events)
        self.kv_event_publisher.publish(event_batch)

2399

2400
2401
2402
2403
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generation request.

    Health checks are recognised by a string ``rid`` starting with
    ``"HEALTH_CHECK"``. A request is not a health check when it has no
    ``rid`` attribute, or when its ``rid`` is not a string (e.g. ``None`` or a
    list of ids). Previously a non-string ``rid`` raised ``AttributeError``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
def _export_static_state(model):
    """Snapshot every buffer of ``model`` as a detached clone.

    Returns:
        ``{"buffers": [(name, tensor), ...]}``, in ``model.named_buffers()``
        order. This is the structure ``_import_static_state`` consumes.
    """
    cloned_buffers = [
        (buf_name, buf.detach().clone()) for buf_name, buf in model.named_buffers()
    ]
    return {"buffers": cloned_buffers}


def _import_static_state(model, static_params):
    """Copy buffers saved by ``_export_static_state`` back into ``model`` in place.

    Raises:
        KeyError: a saved buffer name does not exist on ``model``.
    """
    live_buffers = dict(model.named_buffers())
    for buf_name, saved_value in static_params["buffers"]:
        destination = live_buffers[buf_name]
        destination[...] = saved_value


2418
2419
2420
2421
2422
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
2423
    pp_rank: int,
2424
    dp_rank: Optional[int],
2425
    pipe_writer,
2426
):
2427
    # Generate the prefix
2428
2429
2430
2431
2432
2433
2434
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"
2435

2436
    # Config the process
2437
    kill_itself_when_parent_died()
2438
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
2439
    faulthandler.enable()
2440
    parent_process = psutil.Process().parent()
2441

2442
2443
2444
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])
2445

Wang Ran (汪然)'s avatar
Wang Ran (汪然) committed
2446
    # Configure the logger
2447
    configure_logger(server_args, prefix=prefix)
2448
    suppress_other_loggers()
2449

2450
    # Set cpu affinity to this gpu process
2451
2452
2453
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

2454
2455
2456
2457
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)
2458
    # Create a scheduler and run the event loop
2459
    try:
2460
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
2461
        pipe_writer.send(
Mick's avatar
Mick committed
2462
2463
2464
2465
2466
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
2467
        )
Byron Hsu's avatar
Byron Hsu committed
2468
2469
2470
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
2471
2472
2473
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
Byron Hsu's avatar
Byron Hsu committed
2474
2475
2476
2477
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
2478
2479
2480
2481
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
2482

Byron Hsu's avatar
Byron Hsu committed
2483
        elif disaggregation_mode == DisaggregationMode.DECODE:
2484
2485
2486
2487
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()
Byron Hsu's avatar
Byron Hsu committed
2488

2489
    except Exception:
2490
2491
2492
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)