scheduler.py 99.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from pathlib import Path
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
40
41
42
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
43
44
45
46
47
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
48
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
49
50
51
52
53
54
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
55
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
56
    ReqToMetadataIdxAllocator,
57
    TransferBackend,
58
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
59
)
60
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
61
62
63
64
65
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
66
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
67
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
68
69
70
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
71
72
from sglang.srt.managers.io_struct import (
    AbortReq,
73
    CloseSessionReqInput,
74
    ExpertDistributionReq,
75
    ExpertDistributionReqOutput,
76
77
    FlushCacheReqInput,
    FlushCacheReqOutput,
78
79
    GetInternalStateReq,
    GetInternalStateReqOutput,
80
81
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
82
    HealthCheckOutput,
83
84
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
85
86
    OpenSessionReqInput,
    OpenSessionReqOutput,
87
    ProfileReq,
88
89
    ProfileReqOutput,
    ProfileReqType,
90
91
92
93
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
94
95
    RpcReqInput,
    RpcReqOutput,
96
97
    SetInternalStateReq,
    SetInternalStateReqOutput,
98
99
    SlowDownReqInput,
    SlowDownReqOutput,
100
101
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
102
103
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
104
105
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
106
107
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
108
)
109
from sglang.srt.managers.mm_utils import init_embedding_cache
110
111
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
112
    MultimodalInputs,
113
114
    Req,
    ScheduleBatch,
115
    global_server_args_dict,
116
)
117
118
119
120
121
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
122
123
124
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
125
from sglang.srt.managers.session_controller import Session
126
from sglang.srt.managers.tp_worker import TpModelWorker
127
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
128
from sglang.srt.managers.utils import validate_input_length
129
from sglang.srt.mem_cache.chunk_cache import ChunkCache
130
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
131
from sglang.srt.mem_cache.radix_cache import RadixCache
132
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
133
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
134
from sglang.srt.reasoning_parser import ReasoningParser
135
from sglang.srt.server_args import PortArgs, ServerArgs
136
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
137
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
138
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
139
from sglang.srt.utils import (
140
    DeepEPMode,
141
    DynamicGradMode,
142
143
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
144
    disable_request_logging,
145
    get_available_gpu_memory,
146
    get_bool_env_var,
147
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
148
    kill_itself_when_parent_died,
149
    point_to_point_pyobj,
150
    pyspy_dump_schedulers,
151
    set_gpu_proc_affinity,
152
153
154
    set_random_seed,
    suppress_other_loggers,
)
155
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
156
157
158

logger = logging.getLogger(__name__)

# Debug / tuning switches, read once from the environment at import time.
# SGLANG_TEST_RETRACT: test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: presumably enables step-time recording
# (see step_time_dict in init_metrics) — TODO confirm at the use site.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# SGLANG_GRAMMAR_TIMEOUT overrides the default of 300; the unit is defined by
# the consumer (presumably seconds — confirm).
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
163

164

165
166
@dataclass
class GenerationBatchResult:
    """Result record for one generation batch step.

    The producers/consumers of this record live outside this excerpt; the
    field comments below describe only what the names and types show.
    """

    # Output of the logits processor; Optional, so it may be absent
    # (presumably on ranks/stages that do not compute logits — TODO confirm).
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden states handed between pipeline-parallel stages, if any.
    # NOTE(review): annotated as a tensor, but a PPProxyTensors object may be
    # what is actually stored — confirm against run_batch.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids, when sampling happened for this batch.
    next_token_ids: Optional[List[int]]
    # Per-request extend (input) lengths.
    extend_input_len_per_req: List[int]
    # Per-request start offsets for logprob computation.
    extend_logprob_start_len_per_req: List[int]
    # Batch id.
    bid: int
    # Whether this batch was eligible to run with a CUDA graph.
    can_run_cuda_graph: bool
174
175
176
177
178
179
180
181


@dataclass
class EmbeddingBatchResult:
    """Embeddings produced for one batch, tagged with that batch's id."""

    # Embedding output for the batch (shape is not fixed by this record).
    embeddings: torch.Tensor
    # Identifier of the batch the embeddings belong to.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
182
183
184
185
186
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
187
188
189
190
191
192
193
194
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler for one TP/PP rank.

        Sets up IPC sockets, the tokenizer, the model worker (and optional draft
        worker), memory pools and prefix cache, scheduling state, profiler and
        metrics state, the request dispatch table, and disaggregation queues.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: GPU index this scheduler drives.
            tp_rank: Tensor-parallel rank of this process.
            pp_rank: Pipeline-parallel rank of this process.
            dp_rank: Data-parallel rank or None; forwarded to the workers.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only the first pipeline stage's
        # attention-TP rank 0 owns real sockets; every other rank gets no
        # receivers and senders whose send_pyobj is a no-op.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            # Only the first token id of the end-of-thinking marker is kept.
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget across pipeline
            # stages, but never go below one request per micro batch.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A missing (None) or non-positive size
        # (-1 by convention) disables chunking.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # new_token_ratio starts at init_new_token_ratio (and is reset to it
        # whenever the event loop is idle). new_token_ratio_decay is the
        # per-step decrement that would bring it down to min_new_token_ratio in
        # default_new_token_ratio_decay_steps steps; it is applied elsewhere.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

481
482
    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``:
          * ``skip_tokenizer_init``: tokenizer and processor are both None.
          * multimodal model: a processor is loaded and the tokenizer is taken
            from it.
          * text-only model: a tokenizer is loaded directly and the processor
            is None.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        if self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only model: there is no processor, but keep the attribute
            # defined so every branch leaves the same set of attributes.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Cache selection:
          * ChunkCache when ``chunked_prefill_size`` is set and the radix cache
            is disabled;
          * HiRadixCache when hierarchical caching is enabled;
          * RadixCache otherwise (honouring ``disable_radix_cache`` and KV
            cache events).

        Also sets ``self.decode_mem_cache_buf_multiplier``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        # Multiplier for decode-time memory reservation (consumed elsewhere —
        # presumably in the decode memory check; confirm). It is 1 without
        # speculative decoding; otherwise num_draft_tokens + eagle_topk * num_steps.
        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                server_args.speculative_num_draft_tokens
                + server_args.speculative_eagle_topk * server_args.speculative_num_steps
            )
554
555
556

    def init_metrics(self):
        """Reset throughput, step-time, and speculative-decoding counters.

        When ``self.enable_metrics`` is set, also create a
        SchedulerMetricsCollector labelled with the served model name and the
        engine type ``"unified"``.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    "engine_type": engine_type,
                },
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
572

573
574
    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV cache events are enabled.

        Args:
            kv_events_config: Passed through to ``EventPublisherFactory.create``
                together with this rank's attention-DP rank. KV cache events are
                enabled exactly when the server's ``kv_events_config`` is not None.

        ``self.kv_event_publisher`` is always defined afterwards; it is None when
        events are disabled.
        """
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )
        else:
            # Keep the attribute defined so callers need no hasattr checks.
            self.kv_event_publisher = None
578

Byron Hsu's avatar
Byron Hsu committed
579
    def init_disaggregation(self):
        """Set up queues and metadata buffers for prefill/decode disaggregation.

        Always resolves ``self.transfer_backend``.
          * DECODE mode: builds the KV transfer queue, the pre-allocation queue
            and the pre-allocation token counter.
          * PREFILL mode: builds the bootstrap queue and the in-flight list.
          * Any other mode: nothing else is created.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        is_decode = self.disaggregation_mode == DisaggregationMode.DECODE
        is_prefill = self.disaggregation_mode == DisaggregationMode.PREFILL
        if not (is_decode or is_prefill):
            return

        # KV pool of the speculative draft model (None without a draft worker);
        # both modes hand it to their queue.
        draft_token_to_kv_pool = (
            None
            if self.draft_worker is None
            else self.draft_worker.model_runner.token_to_kv_pool
        )

        if is_decode:
            # *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0
        else:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
652

653
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration receives and processes incoming requests, picks the next
        batch, runs it, and processes its result synchronously. When there is
        no batch to run, it checks memory and resets ``new_token_ratio`` to its
        initial value. The loop never returns.
        """
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
672

673
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the next batch (if any) and queues a copy of it
        with its result in ``self.result_queue``. It then processes the result
        of the previous iteration's batch, passing the current batch's
        ``launch_done`` event. On the very first launch, a DUMMY_FIRST batch is
        processed to prime the pipeline. The loop never returns.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: use the launch_done event of the batch launched in this
                # iteration, not the one belonging to the last batch.
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Each outer iteration walks ``pp_size`` microbatch slots. For slot
        ``mb_id`` it receives requests and schedules and runs a batch. It then
        receives ``next_token_ids`` for slot ``(mb_id + 1) % pp_size`` and
        post-processes that microbatch.

        The last PP rank sends its sampled ``next_token_ids`` via
        ``pp_group.send_tensor_dict``. Non-last ranks forward three things to
        the next stage: the token outputs they received, the received requests,
        and their hidden-state proxy tensors.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # cuda-graph flag of the forward that ran each microbatch on this rank.
        # It is reported when that same microbatch is post-processed, rather than
        # the flag of whichever microbatch happened to run most recently.
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    bids[mb_id] = result.bid
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank and self.cur_batch:
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": result.next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

820
821
    def recv_requests(self) -> List[Req]:
        """Collect incoming requests on the leader rank and share them with peers.

        On attention-TP rank 0 of the first pipeline stage, the tokenizer socket
        and then the RPC socket are drained without blocking. On attention-TP
        rank 0 of later stages, the requests are received point-to-point from
        the previous stage. Other ranks start with None and obtain the list via
        broadcast.

        With DP attention, generate/embedding requests are broadcast only within
        the attention-TP group, while all other (control) requests are broadcast
        across the full TP group.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                for source in (self.recv_from_tokenizer, self.recv_from_rpc):
                    while True:
                        try:
                            recv_reqs.append(source.recv_pyobj(zmq.NOBLOCK))
                        except zmq.ZMQError:
                            break
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(
                    recv_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return recv_reqs

        work_reqs = None
        control_reqs = None
        if self.attn_tp_rank == 0:
            work_reqs, control_reqs = [], []
            for item in recv_reqs:
                is_work = isinstance(
                    item, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                )
                (work_reqs if is_work else control_reqs).append(item)

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
898
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request and route any reply to its sender.

        Health-check generation requests are only counted (not dispatched) while
        a chunked request or a non-empty running batch exists. Replies of type
        ``RpcReqOutput`` go back over the RPC socket (if there is one); every
        other reply goes to the tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                busy = self.chunked_req is not None or not self.running_batch.is_empty()
                if busy:
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
914
915
916
917
918

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generation request and enqueue it.

        Most validation failures do not raise. Instead, the request is marked
        aborted and still passed to ``_add_request_to_queue``. This covers an
        unknown session id, an over-long prompt before or after multimodal
        expansion, an out-of-range ``logprob_start_len``, and a cached invalid
        grammar.

        A disaggregated request without a bootstrap room is aborted and streamed
        back immediately instead. A request whose grammar is not cached yet is
        parked in ``self.grammar_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not present in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Clamp max_new_tokens so that prompt + output fits in max_req_len;
        # a missing (None) limit is treated as 1 << 30 before clamping.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None.
                # A truthiness test here would leave `key` unbound for a present
                # but falsy structural_tag (e.g. "") and raise UnboundLocalError.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Not cached yet: the request waits in the grammar queue.
                req.grammar_key = key
                add_to_grammar_queue = True
            elif value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and enqueue ``req`` for this server's role.

        Prefill servers use the bootstrap queue, decode servers use the
        pre-allocation queue, and all other servers use ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

1082
1083
1084
1085
1086
    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append all of ``reqs`` to the queue matching this server's role.

        This method does not modify ``queue_time_start`` on the requests.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(reqs)
            return
        if mode == DisaggregationMode.DECODE:
            # Decode servers park requests in the pending pre-allocation queue.
            self.disagg_decode_prealloc_queue.extend(reqs)
            return
        self.waiting_queue.extend(reqs)
1090
1091
1092

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        A prompt that is too long, either after multimodal expansion or per
        ``validate_input_length``, is marked aborted and still enqueued so the
        abort is reported through the normal output path.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Mark the request aborted; otherwise an over-long prompt would be
            # scheduled like a valid request.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1135

1136
1137
1138
1139
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a summary line for a newly scheduled prefill batch and update metrics.

        Args:
            adder: the ``PrefillAdder`` that formed the batch. Its
                ``log_input_tokens`` / ``log_hit_tokens`` are reported.
            can_run_list: requests in the new prefill batch (reported as #new-seq).
            running_bs: number of running requests (reported as #running-req).
        """
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"time: {gap_latency:.2f} "
        else:
            f += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(f)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            # Avoid 0/0 when the batch reports neither new nor cached tokens.
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            # Avoid dividing by zero for an empty can_run_list.
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1189

1190
1191
1192
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
    ):
        """Log a summary line for decode progress and update metrics.

        Args:
            can_run_cuda_graph: reported verbatim in the log line.
            running_batch: batch to report on. Falls back to
                ``self.running_batch`` when None (or otherwise falsy).
        """
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Report 0 instead of raising ZeroDivisionError when no speculative
            # forward pass has been counted since the previous log line.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1249

Lianmin Zheng's avatar
Lianmin Zheng committed
1250
1251
    def check_memory(self):
        """Self-check for leaked KV-cache tokens or request slots.

        The event loops call this when no batch is running. It also refreshes
        the idle-time metrics at most every 30 seconds (on attention-TP rank 0,
        when metrics are enabled) and always publishes KV events.

        Raises:
            ValueError: if free + evictable KV tokens differ from
                ``max_total_num_tokens`` (minus the protected size when
                hierarchical caching is enabled). Also raised if not every
                ``req_to_token_pool`` slot is free.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the headline from the counters,
                # matching the KV-pool message above.
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            # available_size already holds allocator free + evictable tokens.
            num_used = self.max_total_num_tokens - available_size
            self.stats.num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1297

1298
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch to run in this scheduler step.

        If the previous batch was an extend batch, it is first filtered and
        folded into ``running_batch``. The current chunked request (whose
        prefill is not finished) is excluded from that merge.

        A new prefill batch is then preferred. If none can be formed, the
        running batch is updated for decoding. Returns None when there is
        nothing to run. With DP attention or SP layernorm enabled, the choice is
        passed through ``prepare_dp_attn_batch``.
        """
        excluded = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that only finished
            # requests get merged into running_batch. It keeps its rid but will
            # get a new req_pool_idx.
            excluded.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # Under pipeline parallelism, after the last chunk the microbatch
                # can still track an outdated chunked_req; discard it too.
                excluded.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        chosen = self.get_new_batch_prefill()
        if chosen is None and not self.running_batch.is_empty():
            # No prefill could be formed: decode the running batch instead.
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                chosen = self.running_batch

        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            chosen, _ = self.prepare_dp_attn_batch(chosen)

        return chosen
1347

1348
1349
1350
1351
1352
1353
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted next to ``running_bs``.

        The budget is ``max_micro_batch_size`` minus the requests already running.
        With pipeline parallelism (``pp_size > 1``) it is additionally capped by the
        free slots reported by ``req_to_token_pool``. The result can be zero or
        negative.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1354
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        A pending chunked request is resumed first. Waiting requests are then
        admitted in ``policy`` priority order through a ``PrefillAdder``. Admission
        stops when any of these limits is hit:

        - the number of distinct LoRA adapters per batch;
        - the number of allocatable request slots;
        - ``req_to_token_pool`` capacity, in PREFILL disaggregation mode;
        - the adder's own result.

        Admitted requests are removed from ``waiting_queue``.

        With mixed chunked prefill (and no logprob requests on either side), the
        current decode batch is mixed into the new batch and ``running_batch`` is
        replaced by an empty batch that keeps the ``batch_is_full`` flag.

        Side effects:
            May set ``running_batch.batch_is_full``, update ``chunked_req``, and
            reassign ``waiting_queue`` / ``running_batch``.

        Returns:
            The new batch, already prepared for extend, or ``None`` if nothing
            could be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or not self.waiting_queue
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        # LoRA adapters already used by the running batch (only consulted when
        # LoRA is enabled).
        lora_set = (
            {r.lora_path for r in self.running_batch.reqs} if self.lora_paths else set()
        )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if self.lora_paths and (
                len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check the actual available size of the request pool.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once; rebuilding it for every element made this
        # filter O(len(waiting_queue) * len(can_run_list)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1508
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next decode step.

        The batch is handled as follows:

        - Finished requests are dropped.
        - If ``check_decode_mem`` fails (or ``TEST_RETRACT`` forces it for batches
          larger than 10), requests are retracted back to the queue and
          ``new_token_ratio`` is replaced by the value ``retract_decode`` returns.
        - Otherwise ``new_token_ratio`` decays toward ``min_new_token_ratio``.
        - ``batch_is_full`` is cleared whenever the batch shrank.

        Args:
            batch: The running decode batch; it is mutated in place.

        Returns:
            The same ``batch``. If it is empty after filtering, it is returned
            without ``prepare_for_decode`` being called.
        """
        size_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        out_of_memory = not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        must_retract = out_of_memory or (TEST_RETRACT and batch.batch_size() > 10)

        if must_retract:
            old_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            # Memory suffices: relax the per-request reservation ratio a bit.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < size_at_entry:
            batch.batch_is_full = False

        # Refresh the batch tensors for the decode step.
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1544

1545
1546
1547
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and package the outputs.

        The method first increments ``forward_ct`` and lets the profiler
        predicate inspect the batch. It then sleeps ``forward_sleep_time``
        seconds when that is set.

        For embedding/reward models it returns an ``EmbeddingBatchResult``.
        For generation models it runs either the regular TP worker or the
        speculative draft worker. On the last pipeline rank it stores the
        sampled ids on ``batch.output_ids``. It returns a
        ``GenerationBatchResult``; non-last pipeline ranks carry proxy hidden
        states instead of logits/token ids.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model: a single forward, no sampling.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        on_last_stage = self.pp_group.is_last_rank
        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            forward_out = self.tp_worker.forward_batch_generation(worker_batch)
            if on_last_stage:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = worker_batch.bid
        else:
            # NOTE(review): this path never binds pp_hidden_states_proxy_tensors,
            # so a non-last PP rank would fail below; presumably spec + PP is
            # not a supported combination.
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            batch_size = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
            self.spec_num_total_forward_ct += batch_size
            self.num_generated_tokens += num_accepted_tokens

        if on_last_stage:
            batch.output_ids = next_token_ids

        # Snapshot these per-request lengths now: the overlap schedule may modify
        # them before the output is processed.
        if batch.return_logprob:
            input_lens = [r.extend_input_len for r in batch.reqs]
            logprob_start_lens = [r.extend_logprob_start_len for r in batch.reqs]
        else:
            input_lens = logprob_start_lens = None

        return GenerationBatchResult(
            logits_output=logits_output if on_last_stage else None,
            pp_hidden_states_proxy_tensors=(
                None if on_last_stage else pp_hidden_states_proxy_tensors
            ),
            next_token_ids=next_token_ids if on_last_stage else None,
            extend_input_len_per_req=input_lens,
            extend_logprob_start_len_per_req=logprob_start_lens,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1618

1619
1620
1621
1622
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch ``result`` handling according to the batch's forward mode.

        Forward modes are handled as follows:

        - Decode and extend batches go to their dedicated processors.
        - An idle batch only does work with the overlap scheduler: it resolves the
          last batch result and signals that the next batch's sampling info is
          ready.
        - A dummy-first batch only signals that sampling info is ready.

        Afterwards, if health-check responses are owed, one
        ``HealthCheckOutput`` is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Answer the health check from here so a long-context prefill does not
        # block it. This path does not verify the detokenizer manager's status.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1643
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` with the other DP attention workers.

        This adapter reads the required configuration from ``self`` and
        ``self.server_args`` and delegates to ``prepare_dp_attn_batch_raw``.

        Returns:
            ``(batch, any_extend)`` exactly as ``prepare_dp_attn_batch_raw``
            returns it.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Exchange per-rank batch statistics and annotate ``local_batch``.

        Each rank contributes a 6-value int64 record via ``all_gather_into_tensor``
        over ``tp_cpu_group``. The record holds: token count, whether CUDA graphs
        can be used, token count for logprobs, whether the batch is extend, and
        two values from ``TboDPAttentionPreparer.prepare_all_gather``.

        If this rank has no batch but some DP worker has tokens, an idle batch
        from ``get_idle_batch`` is used so this rank still takes part.

        Returns:
            ``(local_batch, any_extend)``: the (possibly newly created, possibly
            ``None``) batch and whether any DP worker is running an extend batch.
        """
        # Local token counts.
        if local_batch is None:
            num_tokens = num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # EAGLE processes speculative_num_draft_tokens tokens per request.
                num_tokens *= speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        # CUDA graphs apply only to decode/idle (or absent) batches; with
        # speculative decoding an idle or absent batch cannot use them either.
        # TODO(sang): Support cuda graph when idle batch is there.
        can_cuda_graph = int(
            local_batch is None or local_batch.forward_mode.is_decode_or_idle()
        )
        if not spec_algorithm.is_none() and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()
        tbo_fields = tbo_preparer.prepare_all_gather(
            local_batch,
            deepep_mode,
            enable_deepep_moe,
            enable_two_batch_overlap,
        )
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_fields,
            ],
            dtype=torch.int64,
        )

        # One 6-value row per (dp, attn_tp) rank.
        global_info = torch.empty((dp_size, attn_tp_size, 6), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        # Only the first attention-TP rank of each DP group is read (presumably
        # every rank in a DP group reports the same values).
        per_dp = global_info[:, 0, :]
        global_num_tokens = per_dp[:, 0].tolist()
        can_cuda_graph = min(per_dp[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp[:, 2].tolist()
        is_extend_in_batch = per_dp[:, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Another DP worker has work, so this one must still run (an idle batch).
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is None:
            return local_batch, any(is_extend_in_batch)

        # TODO: handle the case when moe_dense_tp_size != 1
        if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
            local_batch.global_num_tokens = [num_tokens]
            local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
        else:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
        local_batch.tbo_split_seq_index = tbo_split_seq_index
        local_batch.global_forward_mode = global_forward_mode

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1765
1766
1767
1768
1769

    def get_idle_batch(self):
        """Create an empty ``ScheduleBatch`` that has been prepared for idle mode.

        The batch is built over this scheduler's pools, cache, and model settings
        and then passed through ``prepare_for_idle()``.
        """
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1780
1781
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Scans ``grammar_queue`` in order. It waits ``poll_interval`` seconds on each
        pending grammar future and stops at the first one that is not ready.
        Requests already finished (aborted by AbortReq) count as ready. A request
        whose grammar has been polled more than ``GRAMMAR_TIMEOUT / poll_interval``
        times is aborted as timed out. Because the scan stops there, at most one
        request times out per rank per call.

        With tensor parallelism the ranks agree (all-reduce MAX) on the number of
        ready and timed-out requests, so every rank moves the same queue prefix.
        """
        # Wait applied to each future in this call; timeouts are counted in
        # units of this interval.
        poll_interval = 0.03

        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=poll_interval)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_interval:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Block on grammars that other ranks already resolved, so that this
            # rank's ready prefix matches the agreed one.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests come right after the *agreed* ready prefix. Indexing
        # from this rank's local count instead would make a rank that saw fewer
        # ready requests cancel/abort ones it has just resolved above. It would
        # also leave the actually timed-out request (still an unresolved future)
        # to be moved into the waiting queue, and ranks would diverge.
        timeout_start = num_ready_reqs_max
        for i in range(timeout_start, timeout_start + num_timeout_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)
        num_ready_reqs = num_ready_reqs_max + num_timeout_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1845
1846
1847
1848
1849
1850
1851
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Signal that the next batch's sampling info is ready.

        If the next batch has grammars, its regex vocab mask is updated first and
        ``current_stream`` is synchronized before the event is set. Does nothing
        when ``batch.next_batch_sampling_info`` is falsy.
        """
        sampling_info = batch.next_batch_sampling_info
        if not sampling_info:
            return
        if sampling_info.grammars is not None:
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        sampling_info.sampling_info_done.set()

1852
1853
1854
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. While ``cur_batch`` is set,
        it tracks ``forward_ct``. If the counter has not advanced for more than
        ``watchdog_timeout`` seconds, it logs batch/memory diagnostics (unless
        request logging is disabled) and dumps scheduler stacks with py-spy. It
        then waits 5 seconds and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division turned sub-2s timeouts into sleep(0)
            # (a busy loop) and truncated fractional timeouts.
            time.sleep(self.watchdog_timeout / 2)

        # Snapshot cur_batch: the scheduler thread may reset it concurrently, and
        # an AttributeError here would end this thread before SIGQUIT is sent.
        stuck_batch = self.cur_batch
        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if stuck_batch is not None:
                batch_info = (
                    f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                    f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                )
            else:
                batch_info = "self.cur_batch=None, "
            logger.error(
                batch_info
                + f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1886
1887
1888
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether the flush happened."""
        return FlushCacheReqOutput(success=self.flush_cache())
1889

1890
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when no requests are pending. That means the
        waiting queue and the running batch are empty and, with pipeline
        parallelism, every microbatch in ``running_mbs`` is empty too.

        On success it does the following:

        - drops ``cur_batch``/``last_batch``;
        - resets the tree cache and grammar backend;
        - clears the request and KV pools, including the draft worker's pools
          for speculative decoding;
        - zeroes the throughput/speculative counters;
        - empties the CUDA cache.

        Returns:
            True if the cache was flushed, False if requests were still pending.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Use the module logger (not the root ``logging`` module) so this
            # warning follows the scheduler's logging configuration.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
    def get_load(self):
        """Estimate this scheduler's load in tokens.

        The estimate is the sum of two parts:

        - KV-cache tokens in use, excluding evictable tree-cache entries;
        - the prompt lengths of queued requests. These are the waiting queue plus
          the disaggregation bootstrap queue (PREFILL mode) or the prealloc queue
          (DECODE mode).
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        used_tokens = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )

        queued_inputs = [req.origin_input_ids for req in self.waiting_queue]
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            queued_inputs.extend(
                req.origin_input_ids
                for req in self.disagg_prefill_bootstrap_queue.queue
            )
        elif mode == DisaggregationMode.DECODE:
            # Decode-side queue entries wrap the underlying Req in ``.req``.
            queued_inputs.extend(
                req.req.origin_input_ids
                for req in self.disagg_decode_prealloc_queue.queue
            )

        return used_tokens + sum(map(len, queued_inputs))

1948
1949
1950
1951
1952
1953
1954
1955
1956
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the scheduler's internal state.

        The snapshot is a copy of ``global_server_args_dict`` extended with:
          * ``last_gen_throughput``;
          * ``avg_spec_accept_length``, only when speculative decoding is enabled
            and at least one accept step has been recorded;
          * ``step_time_dict``, only when ``RECORD_STEP_TIME`` is set;
          * ``load``, the token-based estimate from ``get_load``.
        """
        ret = dict(global_server_args_dict)
        ret["last_gen_throughput"] = self.last_gen_throughput
        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
            ret["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            ret["step_time_dict"] = self.step_time_dict

        ret["load"] = self.get_load()

        return GetInternalStateReqOutput(internal_state=ret)
1961
1962
1963
1964
1965

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Validate and apply a runtime update of ``global_server_args_dict``.

        Only the keys in ``args_allow_update`` may be changed.
        ``max_micro_batch_size`` must also lie in
        ``[1, max_running_requests // pp_size]``. The update is all-or-nothing:
        if any key is rejected, nothing is written. When the update is applied,
        the cumulative speculative-accept statistics are logged and reset.

        Returns:
            SetInternalStateReqOutput whose ``updated`` flag reports whether the
            update was actually applied, together with the current
            ``global_server_args_dict``.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            # Restart the accept-length statistics under the new thresholds.
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        return SetInternalStateReqOutput(
            # Report the real outcome. This used to be hard-coded to True even
            # when validation above rejected the update.
            updated=if_success,
            server_args=global_server_args_dict,
        )

2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        ``barrier()`` is called after the attempt whether or not the call
        succeeded, so every rank reaches the barrier.

        Returns:
            RpcReqOutput(success, message). ``message`` is the stringified
            exception on failure and ``""`` on success.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        error = None  # renamed from `exec`, which shadowed the builtin
        try:
            # NOTE(review): any scheduler attribute can be invoked by name here;
            # only trusted callers should be able to send RpcReqInput.
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Save the model weights to the remote location ``params["url"]``.

        Delegates to ``self.tp_worker.worker.model_runner.save_remote_model``.
        """
        url = params["url"]

        worker = self.tp_worker.worker

        worker.model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Save the model weights in sharded form.

        Args:
            params: dict with ``"path"``, ``"pattern"`` and ``"max_size"`` keys.
                They are forwarded as keyword arguments to
                ``self.tp_worker.worker.model_runner.save_sharded_model``.
        """
        worker = self.tp_worker.worker

        worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

2035
2036
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose ``rid`` starts with ``recv_req.rid``.

        How a request is aborted depends on where it currently is:
          1. waiting queue: popped immediately, and an AbortReq is sent back to
             the tokenizer;
          2. grammar queue: its grammar is cancelled and it is marked
             finished-with-abort;
          3. running/current batch: flagged with ``to_abort`` so the regular
             decode path releases its KV cache.
        """
        # Delete requests in the waiting queue
        to_del = []
        for i, req in enumerate(self.waiting_queue):
            if req.rid.startswith(recv_req.rid):
                to_del.append(i)

        # Pop from the highest index down so the remaining indices stay valid.
        for i in reversed(to_del):
            # Abort method 1: directly pop from the queue
            # This only works for requests that have not started anything.
            # We still need to send something back to TokenizerManager to clean up the state.
            req = self.waiting_queue.pop(i)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete the requests in the grammar queue
        for req in self.grammar_queue:
            # Abort method 2: call `set_finish_with_abort`
            # The request will still run one prefill forward pass.
            # In this case, we change the input_ids to be only one token to make this prefill cheap.
            if req.rid.startswith(recv_req.rid):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

        # Delete requests in the running batch
        if self.cur_batch is self.running_batch or self.cur_batch is None:
            reqs = self.running_batch.reqs
        else:
            # cur_batch is a different batch from running_batch, so its
            # requests must be scanned as well.
            reqs = self.running_batch.reqs + self.cur_batch.reqs

        for req in reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                # Abort method 3: set `to_abort=True`
                # The request will still run one decode forward pass.
                # Then we reuse all existing code to clean up the KV cache allocation.
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2074

2075
2076
2077
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Hook for pausing the engine; this scheduler does not implement it.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2078
2079
2080
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed, presumably because cached KV entries
        were computed with the old weights. On failure the worker's message is
        logged.

        Returns:
            UpdateWeightFromDiskReqOutput(success, message, 0).
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2087

2088
2089
2090
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result.
        """
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)
2092
2093

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        On success the cache is flushed; on failure the worker's message is
        logged. (The return annotation previously claimed ``Tuple[bool, str]``,
        but the method has always returned an output object.)
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2105

2106
2107
2108
2109
2110
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Unlike the disk/distributed variants, the cache is flushed only when
        ``recv_req.flush_cache`` is set.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)
2117

2118
2119
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a named parameter from the TP worker and wrap it in an output object."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)
2121

2122
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Pause the memory saver, stashing the model's buffers first.

        The buffer snapshot is kept in ``self.stashed_model_static_state`` so
        that ``resume_memory_occupation`` can restore it. The cache is flushed
        afterwards.
        """
        self.memory_saver_adapter.check_validity(
            caller_name="release_memory_occupation"
        )
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.memory_saver_adapter.pause()
        # NOTE(review): the result is ignored. flush_cache() refuses (returns
        # False) while requests are still queued or running.
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2132

2133
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Undo ``release_memory_occupation``.

        Resumes the memory saver, writes the stashed buffer snapshot back into
        the model, and then drops the snapshot.
        """
        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
        self.memory_saver_adapter.resume()
        _import_static_state(
            self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
        )
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

2142
2143
2144
2145
2146
2147
2148
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store the requested ``forward_sleep_time`` on the scheduler.

        A non-positive value is normalized to ``None``. Any other value,
        including ``None``, is stored unchanged.
        """
        requested = recv_req.forward_sleep_time
        is_non_positive = requested is not None and requested <= 0
        self.forward_sleep_time = None if is_non_positive else requested
        return SlowDownReqOutput()

2149
    def profile(self, recv_req: ProfileReq):
        """Handle a start/stop profiling request.

        For ``START_PROFILE`` the configuration is recorded by ``init_profile``.
        With ``profile_by_stage`` the actual start is left to
        ``_profile_batch_predicate``; otherwise profiling starts immediately.
        Any other request type stops the current profile.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_result = self.init_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
            recv_req.profile_id,
        )
        # Stage-based profiling is started per batch by _profile_batch_predicate.
        # If init failed (a profile is already running), report that instead of
        # starting a second profiler on top of the first.
        if recv_req.profile_by_stage or not init_result.success:
            return init_result
        # Immediate profiling is not tied to a forward stage. This previously
        # passed `True` as the stage, which logged "Profiling starts for True".
        return self.start_profile()

2175
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Record the profiling configuration without starting any profiler.

        Args:
            output_dir: trace directory. Defaults to ``$SGLANG_TORCH_PROFILER_DIR``,
                or ``/tmp`` if that is unset.
            num_steps: if truthy, profiling is stopped automatically after this
                many forward steps (counted per stage when ``profile_by_stage``).
                Otherwise it runs until an explicit stop.
            activities: entries understood by ``start_profile`` ("CPU", "GPU",
                "MEM", "RPD", "CUDA_PROFILER"). Defaults to ``["CPU", "GPU"]``.
            with_stack / record_shapes: forwarded to ``torch.profiler.profile``.
            profile_by_stage: profile prefill and decode batches separately.
            profile_id: prefix of the torch trace file name.

        Returns:
            A failure output if a profile is already in progress, else success.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = activities
        self.profile_id = profile_id

        if num_steps:
            self.profile_steps = num_steps
            if self.profile_by_stage:
                # Prefill and decode steps are counted independently.
                self.profiler_target_prefill_ct = num_steps
                self.profiler_target_decode_ct = num_steps
                self.profiler_prefill_ct = 0
                self.profiler_decode_ct = 0
            else:
                self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            # NOTE(review): with profile_by_stage and no num_steps, the per-stage
            # counters are not (re)initialized here. Confirm that combination is
            # rejected upstream.
            self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the profilers selected in ``self.profiler_activities``.

        Recognized activities:
          * "RPD": rocpd tracer. It takes precedence over "CPU"/"GPU".
          * "CPU" / "GPU": ``torch.profiler`` CPU / CUDA activities.
          * "MEM": CUDA caching-allocator memory-history recording.
          * "CUDA_PROFILER": ``torch.cuda.cudart().cudaProfilerStart()``.

        Args:
            stage: forward stage being profiled. Used only in the log message.
        """
        stage_str = f" for {str(stage)}" if stage is not None else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                # Rank 0 (re)creates the shared "trace.rpd" database in the CWD
                # before any rank starts tracing into it.
                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            # Without this, a CUDA_PROFILER-only session could never be stopped:
            # stop_profile() bails out when no profile is in progress.
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2285

2286
2287
2288
2289
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop all active profilers and write their output to disk.

        Output goes to ``self.torch_profiler_output_dir``, which is created if
        missing.

        Args:
            stage: if given, appended to the trace file names, so that the
                per-stage traces produced by ``profile_by_stage`` get distinct
                names.

        Returns:
            A failure output if no profile is in progress, else success.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        if not Path(self.torch_profiler_output_dir).exists():
            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{str(stage)}" if stage is not None else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            # Every rank must have flushed into trace.rpd before rank 0 converts it.
            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # Previously cleared the misspelled `rpd_profiler_path`, which left
            # the real attribute set.
            self.rpd_profile_path = None

        # Guard both membership checks the same way, in case activities were
        # never configured.
        activities = self.profiler_activities or []
        if "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            else:
                raise RuntimeError("unsupported profile stage")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2377

2378
2379
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the global expert-distribution recorder.

        Raises:
            ValueError: if ``recv_req`` is not one of START_RECORD, STOP_RECORD
                or DUMP_RECORD.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            get_global_expert_distribution_recorder().start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            get_global_expert_distribution_recorder().stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            get_global_expert_distribution_recorder().dump_record()
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
2388

2389
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new ``Session`` registered under ``recv_req.session_id``.

        Returns:
            OpenSessionReqOutput(session_id, ok). ``ok`` is False (and a warning
            is logged) when the id is already in use or is None.
        """
        # handle error
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        elif session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        else:
            self.sessions[session_id] = Session(
                recv_req.capacity_of_str_len, session_id
            )
            return OpenSessionReqOutput(session_id, True)
2403
2404
2405
2406
2407
2408
2409
2410
2411

    def close_session(self, recv_req: CloseSessionReqInput):
        """Forget the session ``recv_req.session_id``; log a warning if it is unknown."""
        target_id = recv_req.session_id
        if target_id in self.sessions:
            del self.sessions[target_id]
            return
        logger.warning(f"session id {target_id} does not exist, cannot delete.")

2412
2413
    def get_print_prefix(self):
        """Return a log prefix identifying this rank, e.g. ``" DP0 TP1 PP0"``.

        A component is included only when it applies:
          * DP: an attention-DP rank is set;
          * TP: ``server_args.tp_size > 1``;
          * PP: ``pp_size > 1``.
        Returns ``""`` when none of them apply.
        """
        prefix = ""
        if self.attn_dp_rank is not None:
            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
            prefix += f" PP{self.pp_rank}"
        return prefix

2422
2423
2424
2425
2426
2427
2428
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2429

2430
2431
2432
2433
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health checks are recognized by a string ``rid`` starting with
    ``"HEALTH_CHECK"``. Requests without an ``rid`` attribute, or whose ``rid``
    is not a string (e.g. ``None``), are not health checks; the previous version
    raised AttributeError on a ``None`` rid.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2448
2449
2450
2451
2452
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Steps:
      1. Configure the process: title, parent-death handling, faulthandler,
         logging, optional CPU affinity, and the VLM embedding cache.
      2. Build the ``Scheduler`` and report readiness (plus token limits)
         through ``pipe_writer``.
      3. Run the event loop that matches the disaggregation mode,
         pipeline-parallel size and overlap setting.

    If creating or running the scheduler raises, the traceback is logged and
    the parent process is sent ``SIGQUIT``.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolve this before building the prefix. Otherwise router-launched DP
    # workers get a log prefix and process title without their DP rank.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB (default 100), overridable via env var.
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)