scheduler.py 98.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
39
40
41
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
42
43
44
45
46
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
47
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
48
49
50
51
52
53
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
54
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
55
    ReqToMetadataIdxAllocator,
56
    TransferBackend,
57
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
58
)
59
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
60
61
62
63
64
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
65
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
66
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
67
68
69
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
70
71
from sglang.srt.managers.io_struct import (
    AbortReq,
72
    CloseSessionReqInput,
73
    ExpertDistributionReq,
74
    ExpertDistributionReqOutput,
75
76
    FlushCacheReqInput,
    FlushCacheReqOutput,
77
78
    GetInternalStateReq,
    GetInternalStateReqOutput,
79
80
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
81
    HealthCheckOutput,
82
83
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
84
85
    OpenSessionReqInput,
    OpenSessionReqOutput,
86
    ProfileReq,
87
88
    ProfileReqOutput,
    ProfileReqType,
89
90
91
92
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
93
94
    RpcReqInput,
    RpcReqOutput,
95
96
    SetInternalStateReq,
    SetInternalStateReqOutput,
97
98
    SlowDownReqInput,
    SlowDownReqOutput,
99
100
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
101
102
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
103
104
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
105
106
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
107
)
108
from sglang.srt.managers.mm_utils import init_embedding_cache
109
110
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
111
    MultimodalInputs,
112
113
    Req,
    ScheduleBatch,
114
    global_server_args_dict,
115
)
116
117
118
119
120
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
121
122
123
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
124
from sglang.srt.managers.session_controller import Session
125
from sglang.srt.managers.tp_worker import TpModelWorker
126
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
127
from sglang.srt.managers.utils import validate_input_length
128
from sglang.srt.mem_cache.chunk_cache import ChunkCache
129
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
130
from sglang.srt.mem_cache.radix_cache import RadixCache
131
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
132
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
133
from sglang.srt.reasoning_parser import ReasoningParser
134
from sglang.srt.server_args import PortArgs, ServerArgs
135
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
136
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
137
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
138
from sglang.srt.utils import (
139
    DeepEPMode,
140
    DynamicGradMode,
141
142
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
143
    disable_request_logging,
144
    get_available_gpu_memory,
145
    get_bool_env_var,
146
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
147
    kill_itself_when_parent_died,
148
    point_to_point_pyobj,
149
    pyspy_dump_schedulers,
150
    set_gpu_proc_affinity,
151
152
153
    set_random_seed,
    suppress_other_loggers,
)
154
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
155
156
157

logger = logging.getLogger(__name__)

# Environment switches, read once when this module is imported.
# SGLANG_TEST_RETRACT: force the retract-decode path (debugging aid).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: boolean switch; presumably enables per-step timing
# records — confirm at the use sites.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# SGLANG_GRAMMAR_TIMEOUT: parsed as a float, default 300 (presumably seconds —
# confirm at the use sites).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))
162

163

164
165
@dataclass
class GenerationBatchResult:
    """Result record for one forward step of a generation batch.

    All fields are required and positional in the declared order, so do not
    reorder them without updating every constructor call.
    """

    # Model logits output; Optional, so it may be absent for some steps/ranks.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors; by name, used for pipeline parallelism
    # (NOTE(review): confirm against producers).
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids; Optional.
    next_token_ids: Optional[List[int]]

    # Per-request extend input lengths and logprob start offsets.
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]

    # Batch id of the batch that produced this result.
    bid: int
    # Whether this step was eligible to run through a CUDA graph.
    can_run_cuda_graph: bool
173
174
175
176
177
178
179
180


@dataclass
class EmbeddingBatchResult:
    """Result record for one forward step of an embedding batch.

    Fields are positional in the declared order: ``embeddings`` then ``bid``.
    """

    # Embedding output tensor (shape not fixed here — confirm with producer).
    embeddings: torch.Tensor
    # Batch id of the batch that produced these embeddings.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
181
182
183
184
185
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
186
187
188
189
190
191
192
193
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, workers, caches, policy and queues.

        Args:
            server_args: Server configuration; most scheduler settings are read
                from it directly.
            port_args: IPC endpoint names (scheduler input, tokenizer,
                detokenizer, rpc) and the NCCL port passed to the workers.
            gpu_id: GPU index handed to the TP worker (and draft worker).
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or None.

        Ordering matters below: ``init_tokenizer`` must run before the
        reasoning-parser and grammar setup, the TP worker must exist before
        ``init_memory_pool_and_cache`` (which reads its memory pool and
        ``self.tp_cpu_group``), and the tree cache must exist before the
        SchedulePolicy and the disaggregation queues are built.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size
        # NOTE(review): dp_size was already assigned from the same source above;
        # this reassignment is redundant.
        self.dp_size = server_args.dp_size
        # Derive this rank's position within the attention TP/DP layout.
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        # Only the first pipeline stage's attention-TP leader owns real sockets;
        # every other rank gets None receivers and no-op senders below.
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            # Stubs exposing send_pyobj that silently drop outgoing objects.
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            # Store only the first token id of the encoded think-end marker.
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            # Imported lazily: only needed when EAGLE speculation is selected.
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        # Default the micro-batch size to the per-pipeline-stage share of
        # max_running_requests (at least 1).
        # NOTE(review): global_server_args_dict.update(...) below runs after this
        # default is set; confirm the worker dict cannot overwrite it.
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # Any non-positive size (e.g. -1) disables chunked prefill.
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        # Mixed chunking is only meaningful when chunked prefill is enabled.
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        # A grammar backend needs a tokenizer, so none is created when
        # tokenizer init is skipped.
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # Starting ratio = global default scaled by conservativeness, capped at 1.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        # Floor for the ratio, a fixed fraction of the starting value (capped at 1).
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement that spans init -> min over the configured number
        # of decay steps (presumably applied elsewhere each step).
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        # Daemon thread, so it does not keep the process alive on exit.
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher
        # Maps each incoming request type to the handler method that serves it.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Disaggregation setup runs last: it needs the memory pools, the tree
        # cache, the draft worker and max_running_requests created above.
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

480
481
    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``.
        - When ``skip_tokenizer_init`` is set, both tokenizer and processor are None.
        - For multimodal models, the processor is loaded and the tokenizer is
          taken from it.
        - Otherwise, a plain tokenizer is loaded and ``self.processor`` is None.
          Previously the attribute was left undefined on this path.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        else:
            if self.model_config.is_multimodal:
                self.processor = get_processor(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                    use_fast=not server_args.disable_fast_image_processor,
                )
                self.tokenizer = get_tokenizer_from_processor(self.processor)
            else:
                # Text-only models have no processor; keep the attribute
                # defined so every path exposes the same set of attributes.
                self.processor = None
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
                    tokenizer_mode=server_args.tokenizer_mode,
                    trust_remote_code=server_args.trust_remote_code,
                    revision=server_args.revision,
                )

    def init_memory_pool_and_cache(self):
        """Fetch the worker's memory pools and build the prefix (tree) cache.

        Cache selection:
        - ChunkCache when chunked_prefill_size is not None and the radix cache
          is disabled;
        - otherwise HiRadixCache when hierarchical caching is enabled;
        - otherwise RadixCache, with ``disable`` mirroring disable_radix_cache.

        Also sets ``decode_mem_cache_buf_multiplier``. It is 1 without
        speculative decoding. With speculation it is
        ``num_draft_tokens + eagle_topk * num_steps``.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        pools = dict(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

        use_chunk_cache = (
            args.chunked_prefill_size is not None and args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(**pools, page_size=self.page_size)
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                **pools,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                **pools,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens
                + args.speculative_eagle_topk * args.speculative_num_steps
            )
553
554
555

    def init_metrics(self):
        """Reset throughput and speculative-decoding counters and set up metrics export.

        ``self.stats`` is always created. ``self.metrics_collector`` is created
        only when ``self.enable_metrics`` is true; otherwise it is left unset.
        """
        # Last computed generation / input throughput (updated elsewhere).
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # Dict[batch size -> step time]
        self.step_time_dict = defaultdict(list)
        # Speculative-decoding acceptance counters, all starting at zero.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        metric_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=metric_labels)
Lianmin Zheng's avatar
Lianmin Zheng committed
571

572
573
574
575
    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create ``self.kv_event_publisher`` if KV-cache events are enabled.

        Does nothing (and leaves the attribute unset) when
        ``self.enable_kv_cache_events`` is false.
        """
        if not self.enable_kv_cache_events:
            return
        self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)

Byron Hsu's avatar
Byron Hsu committed
576
    def init_disaggregation(self):
        """Set up prefill/decode disaggregation state for this rank.

        Always records ``self.transfer_backend``.
        - DECODE mode: builds metadata buffers sized at twice the request-pool
          size, the KV transfer queue, and the pre-allocation queue. Also resets
          ``num_tokens_pre_allocated``.
        - PREFILL mode: builds metadata buffers sized at twice
          ``max_running_requests``, the bootstrap queue, and an empty in-flight
          list.
        - Any other mode: nothing further is created.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.DECODE:
            # Twice the request pool size, for headroom.
            buffer_size = self.req_to_token_pool.size * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # Decode requests that are polling for their KV cache.
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # Decode requests waiting for KV pre-allocation.
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    self.draft_worker.model_runner.token_to_kv_pool
                    if self.draft_worker is not None
                    else None
                ),
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif mode == DisaggregationMode.PREFILL:
            # Twice the max running requests, for headroom.
            buffer_size = self.max_running_requests * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    self.draft_worker.model_runner.token_to_kv_pool
                    if self.draft_worker is not None
                    else None
                ),
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # Prefill requests whose KV cache is still being sent.
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
649

650
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal (non-overlapped) scheduler loop; never returns.

        Each iteration does the following:
        1. Receive and process new requests.
        2. Choose the next batch.
        3. If there is a batch, run it and process its result in the same
           iteration. If there is none, do the idle self-check and reset
           ``new_token_ratio``.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle: self-check and restore the initial new-token ratio.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                outcome = self.run_batch(batch)
                self.process_batch_result(batch, outcome)

            self.last_batch = batch
669

670
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Results are processed one iteration late. The batch launched in
        iteration N is queued in ``self.result_queue``. Its result is popped and
        processed in iteration N+1, after batch N+1 (if any) has been launched.
        The loop never returns.
        """
        # FIFO of (batch copy, result) pairs awaiting processing.
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                # Fresh event per launched batch; it is passed to
                # process_batch_result below (presumably signalled once the
                # launch completes — confirm in run_batch / the worker).
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                # Queue a copy, so later mutation of `batch` does not affect
                # the pending result processing.
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                # Attach the sampling info of the just-launched batch, or None
                # if nothing was launched this iteration.
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: use the current launched batch's launch_done event instead of the last batch's.
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                # NOTE(review): a falsy-but-not-None batch with a falsy
                # last_batch takes neither branch; confirm that cannot happen.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The scheduler cycles through ``pp_size`` microbatch slots. For each slot
        it receives new requests, schedules and runs a batch, and forwards data
        along the pipeline:

        - the last rank sends sampled token ids;
        - other ranks forward outputs, requests and hidden-state proxy tensors.

        It then post-processes the microbatch of the *next* slot, using the
        token ids received from the pipeline.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Per-microbatch ``can_run_cuda_graph`` flag, recorded when the
        # microbatch runs. A microbatch is post-processed one slot later, so
        # reading the current ``result`` there would report the flag of a
        # different microbatch (the one that ran most recently).
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

817
818
    def recv_requests(self) -> List[Req]:
        """Receive new requests on the leader rank and broadcast them to the TP peers.

        Only the attention-TP leader (``attn_tp_rank == 0``) receives anything:

        - On the first pipeline stage it drains the tokenizer socket, then the
          RPC socket, without blocking.
        - On later stages it receives the list forwarded by the previous stage
          through ``point_to_point_pyobj``.

        Other ranks start from ``None`` and get the list through the broadcasts
        below. With DP attention, work requests (tokenized generate/embedding
        inputs) are broadcast within the attention-TP group. All other (control)
        requests are broadcast within the full TP group.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                # Drain each socket without blocking: tokenizer first, then RPC.
                for socket in (self.recv_from_tokenizer, self.recv_from_rpc):
                    while True:
                        try:
                            recv_reqs.append(socket.recv_pyobj(zmq.NOBLOCK))
                        except zmq.ZMQError:
                            break
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )

        if self.server_args.enable_dp_attention:
            work_reqs = control_reqs = None
            if self.attn_tp_rank == 0:
                work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    if isinstance(req, work_types):
                        work_reqs.append(req)
                    else:
                        control_reqs.append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
895
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request and forward any immediate reply.

        A health-check generate request is skipped (and counted in
        ``self.return_health_check_ct``) while a chunked request exists or the
        running batch is non-empty. Replies of type ``RpcReqOutput`` go back on
        the RPC socket when one exists. All other replies go to the tokenizer.
        """
        for recv_req in recv_reqs:
            if is_health_check_generate_req(recv_req):
                server_busy = (self.chunked_req is not None) or (
                    not self.running_batch.is_empty()
                )
                if server_busy:
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(recv_req)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
911
912
913
914
915

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        A fresh ``Req`` is created, unless the request refers to an existing
        session, in which case that session creates it. The request is then
        validated:

        - disaggregation bootstrap room;
        - multimodal expansion length;
        - prompt length;
        - ``logprob_start_len``.

        A missing bootstrap room is aborted and streamed back immediately. The
        other failures are marked aborted and enqueued through
        ``_add_request_to_queue``.

        A request with a structured-output constraint (JSON schema, regex, EBNF
        or structural tag) goes to ``grammar_queue`` on a grammar-cache miss.
        Everything else goes to ``_add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but it is unknown.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap generation so prompt + output stays within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # Only structural_tag can remain (the guard above tested it with
                # `is not None`), so `key` is always bound, even for an empty tag.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True
            elif value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Record when ``req`` entered the queue, then enqueue it.

        The destination depends on ``self.disaggregation_mode``:

        - PREFILL: the prefill bootstrap queue;
        - DECODE: the decode prealloc queue;
        - otherwise: the regular waiting queue.
        """
        req.queue_time_start = time.perf_counter()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

1078
1079
1080
1081
1082
    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append several requests at once to the queue for the current mode.

        Queue selection matches ``_add_request_to_queue``:

        - PREFILL: the prefill bootstrap queue;
        - DECODE: the decode pending prealloc queue;
        - otherwise: the waiting queue.

        Unlike ``_add_request_to_queue``, this does not set
        ``queue_time_start``.
        """
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            target_queue = self.disagg_prefill_bootstrap_queue
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            target_queue = self.disagg_decode_prealloc_queue
        else:
            target_queue = self.waiting_queue
        target_queue.extend(reqs)
1086
1087
1088

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs are expanded into dummy tokens first. A prompt that
        is too long, either after expansion or as reported by
        ``validate_input_length``, is marked aborted via
        ``req.set_finish_with_abort``. It is still enqueued through
        ``_add_request_to_queue``, matching ``handle_generate_request``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Abort the same way handle_generate_request does for this error.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1131

1132
1133
1134
1135
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a summary of a newly scheduled prefill batch and export metrics.

        Args:
            adder: The ``PrefillAdder`` that assembled the batch. Its
                ``log_input_tokens`` and ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of running requests, reported as ``#running-req``.
        """
        # One timestamp, so the reported gap and the next tic agree exactly.
        now = time.perf_counter()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"time: {gap_latency:.2f} "
        else:
            f += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(f)

        if self.enable_metrics:
            # Guard both ratios so empty counters cannot raise ZeroDivisionError.
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1185

1186
1187
1188
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a periodic decode summary and export metrics.

        Args:
            can_run_cuda_graph: Reported as ``cuda graph`` in the log line.
            running_batch: Batch whose request count is reported. Falls back to
                ``self.running_batch`` when not given (or falsy).
        """
        batch = running_batch or self.running_batch

        # One timestamp, so the reported gap and the next tic agree exactly.
        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Avoid ZeroDivisionError when no speculative forward pass has been
            # counted since the counters were last reset.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct
                else 0.0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1245

Lianmin Zheng's avatar
Lianmin Zheng committed
1246
1247
    def check_memory(self):
        """Idle-time self-check of the scheduler's memory accounting.

        Verifies that the KV token pool, together with evictable cache, accounts
        for the full capacity. With hierarchical cache enabled, protected tokens
        are subtracted from that capacity. Also verifies that every
        ``req_to_token_pool`` slot is free.

        With metrics enabled on attention-TP rank 0, it also exports idle-time
        stats at most every 30 seconds. It always finishes by publishing
        pending KV events.

        Raises:
            ValueError: If either pool's accounting does not match (a leak).
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the headline from the details.
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1293

1294
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Select the batch for the next forward pass.

        1. If the previous batch was an extend (prefill) batch, fold it into
           ``running_batch``. Chunked requests with chunks still pending are
           excluded.
        2. Prefer a new prefill batch.
        3. Otherwise update the running batch and use it for decode, if it
           still has requests.

        With DP attention or SP layernorm enabled, the choice is passed
        through ``prepare_dp_attn_batch``.

        Returns:
            The batch to run, or ``None`` when there is nothing to do.
        """
        excluded = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that only finished
            # requests are merged into running_batch. It keeps its rid but will
            # get a new req_pool_idx.
            excluded.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the current
                # microbatch may still track an outdated chunked_req; drop it.
                excluded.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(prev)

        # Run prefill first if possible; otherwise fall back to decode.
        chosen = self.get_new_batch_prefill()
        if chosen is None and not self.running_batch.is_empty():
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                chosen = self.running_batch

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            chosen, _ = self.prepare_dp_attn_batch(chosen)

        return chosen
1343

1344
1345
1346
1347
1348
1349
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may join a micro-batch of ``running_bs`` requests.

        The headroom is ``max_micro_batch_size - running_bs``. With pipeline
        parallelism (``pp_size > 1``) it is further capped by the number of free
        slots in ``req_to_token_pool``. The result may be zero or negative.
        """
        headroom = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return headroom
        return min(headroom, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1350
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
Lianmin Zheng's avatar
Lianmin Zheng committed
1351
        # Check if the grammar is ready in the grammar queue
1352
        if self.grammar_queue:
1353
            self.move_ready_grammar_requests()
1354

Lianmin Zheng's avatar
Lianmin Zheng committed
1355
1356
        # Handle the cases where prefill is not allowed
        if (
Lianmin Zheng's avatar
Lianmin Zheng committed
1357
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
1358
        ) and self.chunked_req is None:
Lianmin Zheng's avatar
Lianmin Zheng committed
1359
1360
            return None

Lianmin Zheng's avatar
Lianmin Zheng committed
1361
        running_bs = len(self.running_batch.reqs)
1362
        # Ignore the check if self.chunked_req is not None.
1363
1364
1365
1366
1367
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
Lianmin Zheng's avatar
Lianmin Zheng committed
1368
            self.running_batch.batch_is_full = True
1369
1370
            return None

1371
1372
1373
1374
1375
        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

1376
1377
1378
        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

Lianmin Zheng's avatar
Lianmin Zheng committed
1379
        # Prefill policy
1380
1381
        adder = PrefillAdder(
            self.tree_cache,
1382
            self.token_to_kv_pool_allocator,
1383
1384
1385
1386
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
1387
            running_bs if self.is_mixed_chunk else 0,
1388
1389
        )

Lianmin Zheng's avatar
Lianmin Zheng committed
1390
        if self.chunked_req is not None:
1391
1392
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)
1393

Lianmin Zheng's avatar
Lianmin Zheng committed
1394
        if self.lora_paths:
Lianmin Zheng's avatar
Lianmin Zheng committed
1395
1396
            lora_set = set([req.lora_path for req in self.running_batch.reqs])

1397
        # Get requests from the waiting queue to a new prefill batch
1398
1399
        for req in self.waiting_queue:
            if (
Lianmin Zheng's avatar
Lianmin Zheng committed
1400
                self.lora_paths
1401
1402
1403
1404
1405
1406
1407
                and len(
                    lora_set
                    | set([req.lora_path for req in adder.can_run_list])
                    | set([req.lora_path])
                )
                > self.max_loras_per_batch
            ):
Lianmin Zheng's avatar
Lianmin Zheng committed
1408
                self.running_batch.batch_is_full = True
1409
1410
                break

1411
            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
Lianmin Zheng's avatar
Lianmin Zheng committed
1412
                self.running_batch.batch_is_full = True
1413
                break
1414

Byron Hsu's avatar
Byron Hsu committed
1415
1416
1417
1418
1419
1420
1421
            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check if the available size for the actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

1422
1423
1424
1425
            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )
1426

1427
1428
1429
            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
1430

1431
1432
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
1433
1434
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
Lianmin Zheng's avatar
Lianmin Zheng committed
1435
1436
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
1437
                        ) > 0 or (not self.running_batch.is_empty())
1438
                    else:
Lianmin Zheng's avatar
Lianmin Zheng committed
1439
                        self.running_batch.batch_is_full = True
1440
1441
                break

Lianmin Zheng's avatar
Lianmin Zheng committed
1442
        # Update waiting queue
1443
        can_run_list: List[Req] = adder.can_run_list
Lianmin Zheng's avatar
Lianmin Zheng committed
1444
1445
        if len(can_run_list) == 0:
            return None
1446
1447
1448
1449

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
1450
                req.queue_time_end = time.perf_counter()
1451

Lianmin Zheng's avatar
Lianmin Zheng committed
1452
1453
1454
        self.waiting_queue = [
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]
1455

1456
        if self.enable_hierarchical_cache:
1457
            self.tree_cache.ready_to_load_cache()
1458

1459
1460
1461
        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req
1462

1463
1464
        if self.chunked_req:
            self.chunked_req.is_chunked += 1
Lianmin Zheng's avatar
Lianmin Zheng committed
1465

1466
        # Print stats
1467
        if self.attn_tp_rank == 0:
1468
            self.log_prefill_stats(adder, can_run_list, running_bs)
1469

Lianmin Zheng's avatar
Lianmin Zheng committed
1470
        # Create a new batch
1471
1472
1473
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
1474
            self.token_to_kv_pool_allocator,
1475
            self.tree_cache,
1476
            self.model_config,
1477
            self.enable_overlap,
1478
            self.spec_algorithm,
1479
            self.server_args.enable_custom_logit_processor,
1480
            chunked_req=self.chunked_req,
1481
        )
1482
        new_batch.prepare_for_extend()
1483

Lianmin Zheng's avatar
Lianmin Zheng committed
1484
        # Mixed-style chunked prefill
1485
1486
        if (
            self.is_mixed_chunk
Lianmin Zheng's avatar
Lianmin Zheng committed
1487
            and not self.running_batch.is_empty()
1488
1489
1490
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
1491
1492
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
1493
                self.running_batch.prepare_for_decode()
1494
1495
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
Lianmin Zheng's avatar
Lianmin Zheng committed
1496
1497
1498
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
Lianmin Zheng's avatar
Lianmin Zheng committed
1499
1500
        else:
            new_batch.decoding_reqs = None
Lianmin Zheng's avatar
Lianmin Zheng committed
1501
1502
1503

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1504
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
1505
        """Update the current running decoding batch."""
Lianmin Zheng's avatar
Lianmin Zheng committed
1506
        initial_bs = batch.batch_size()
Lianmin Zheng's avatar
Lianmin Zheng committed
1507

1508
1509
        batch.filter_batch()
        if batch.is_empty():
Lianmin Zheng's avatar
Lianmin Zheng committed
1510
1511
            batch.batch_is_full = False
            return batch
1512

Lianmin Zheng's avatar
Lianmin Zheng committed
1513
        # Check if decode out of memory
1514
        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
1515
            TEST_RETRACT and batch.batch_size() > 10
1516
        ):
Lianmin Zheng's avatar
Lianmin Zheng committed
1517
1518
            old_ratio = self.new_token_ratio

1519
            retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
Lianmin Zheng's avatar
Lianmin Zheng committed
1520
            self.new_token_ratio = new_token_ratio
1521

Lianmin Zheng's avatar
Lianmin Zheng committed
1522
            logger.info(
1523
                "KV cache pool is full. Retract requests. "
Lianmin Zheng's avatar
Lianmin Zheng committed
1524
1525
1526
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
1527
            self._extend_requests_to_queue(retracted_reqs)
Lianmin Zheng's avatar
Lianmin Zheng committed
1528
1529
        else:
            self.new_token_ratio = max(
1530
                self.new_token_ratio - self.new_token_ratio_decay,
Lianmin Zheng's avatar
Lianmin Zheng committed
1531
1532
1533
                self.min_new_token_ratio,
            )

Lianmin Zheng's avatar
Lianmin Zheng committed
1534
        if batch.batch_size() < initial_bs:
Lianmin Zheng's avatar
Lianmin Zheng committed
1535
            batch.batch_is_full = False
Lianmin Zheng's avatar
Lianmin Zheng committed
1536
1537

        # Update batch tensors
1538
        batch.prepare_for_decode()
Lianmin Zheng's avatar
Lianmin Zheng committed
1539
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1540

1541
1542
1543
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Bumps ``self.forward_ct``, gives the profiler hook a look at the batch,
        and optionally sleeps ``self.forward_sleep_time`` seconds. It then runs
        the forward pass: embedding/reward models go through
        ``forward_batch_embedding``; generation uses the TP worker, or the draft
        worker when a speculative algorithm is configured.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            if self.pp_group.is_last_rank:
                logits_output, next_token_ids, can_run_cuda_graph = (
                    self.tp_worker.forward_batch_generation(worker_batch)
                )
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                    self.tp_worker.forward_batch_generation(worker_batch)
                )
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        on_last_rank = self.pp_group.is_last_rank
        if on_last_rank:
            batch.output_ids = next_token_ids

        # The overlap scheduler may modify these per-request values before the
        # output is processed, so snapshot them now.
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]
        else:
            extend_input_len_per_req = extend_logprob_start_len_per_req = None

        # Conditional expressions keep the unbound-on-this-rank names unevaluated.
        return GenerationBatchResult(
            logits_output=logits_output if on_last_rank else None,
            pp_hidden_states_proxy_tensors=(
                None if on_last_rank else pp_hidden_states_proxy_tensors
            ),
            next_token_ids=next_token_ids if on_last_rank else None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1615

1616
1617
1618
1619
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
1620
        launch_done: Optional[threading.Event] = None,
1621
    ):
Lianmin Zheng's avatar
Lianmin Zheng committed
1622
        if batch.forward_mode.is_decode():
1623
            self.process_batch_result_decode(batch, result, launch_done)
1624
        elif batch.forward_mode.is_extend():
1625
            self.process_batch_result_prefill(batch, result, launch_done)
1626
1627
        elif batch.forward_mode.is_idle():
            if self.enable_overlap:
1628
                self.tp_worker.resolve_last_batch_result(launch_done)
1629
                self.set_next_batch_sampling_info_done(batch)
1630
        elif batch.forward_mode.is_dummy_first():
1631
            self.set_next_batch_sampling_info_done(batch)
Lianmin Zheng's avatar
Lianmin Zheng committed
1632

1633
1634
1635
1636
1637
1638
1639
        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1640
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across DP-attention ranks.

        Thin adapter: collects the relevant settings from ``self`` and
        ``self.server_args`` and delegates to ``prepare_dp_attn_batch_raw``.
        Returns that method's ``(batch, any_extend_in_batch)`` tuple unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size: int,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Exchange per-rank batch metadata across DP-attention workers.

        Each rank contributes a small int64 vector with the following entries:
        - token count;
        - a can-use-cuda-graph flag;
        - the logprob token count;
        - an is-extend flag;
        - the two-batch-overlap values from ``TboDPAttentionPreparer``.

        The vectors are all-gathered over ``tp_cpu_group``. From the result:
        - global token counts are attached to ``local_batch``;
        - cuda graph is allowed only if every DP group allows it;
        - a rank with no batch gets an idle batch (via ``get_idle_batch``)
          when any other DP group has tokens.

        Args:
            local_batch: This rank's batch, or None if it has nothing to run.
            dp_size / attn_tp_size: Leading dims of the gathered tensor.
                Presumably ``tp_cpu_group`` spans ``dp_size * attn_tp_size``
                ranks; TODO confirm.
            get_idle_batch: Zero-arg factory for an empty idle batch.

        Returns:
            ``(local_batch, any_extend)``: ``local_batch`` is possibly replaced
            by an idle batch. ``any_extend`` is True if any DP group runs an
            extend batch.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            # EAGLE verifies speculative_num_draft_tokens tokens per request.
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Per-rank payload. Layout: [0] num_tokens, [1] can_cuda_graph,
        # [2] num_tokens_for_logprob, [3] is_extend_in_batch, [4:] TBO values.
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        # NOTE(review): the trailing 6 (and the 4:6 slice below) assumes
        # prepare_all_gather yields exactly 2 values, so local_info has 6
        # entries. TODO confirm if TboDPAttentionPreparer changes.
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
        )
        # global_info is freshly allocated and contiguous, so flatten() returns
        # a view and the gather fills global_info in place.
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Read scalar fields from attention-TP index 0 of each DP group.
        global_num_tokens = global_info[:, 0, 0].tolist()
        # Cuda graph is used only if every DP group allows it.
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # This rank has no work but some other DP group does: run an idle batch.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1762
1763
1764
1765
1766

    def get_idle_batch(self):
        """Build a request-less ``ScheduleBatch`` prepared for an idle forward pass.

        ``prepare_dp_attn_batch`` passes this method as the factory used when
        this rank has no batch but another DP group has tokens to process.
        """
        shared_state = (
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle = ScheduleBatch.init_new([], *shared_state)
        idle.prepare_for_idle()
        return idle

1777
1778
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Grammar futures are polled in queue order, each with a short timeout, and
        polling stops at the first one that is not ready. A request already
        finished (aborted) counts as ready. A request polled more than
        ``GRAMMAR_TIMEOUT / poll_timeout`` times is treated as timed out and
        aborted with its grammar cached as ``INVALID_GRAMMAR_OBJ``.

        With more than one TP rank, the (ready, timeout) counts are MAX-reduced
        so that every rank moves the same queue prefix:
        - ranks that are behind block on the remaining futures up to the agreed
          ready count;
        - the timed-out request is the one right after that agreed prefix on
          every rank.
        """
        poll_timeout = 0.03
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                if req.finished():  # It is aborted by AbortReq
                    num_ready_reqs += 1
                    continue
                req.grammar = req.grammar.result(timeout=poll_timeout)
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
                num_ready_reqs += 1
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            # Catch up with the fastest rank by blocking on the remaining futures.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                req.grammar = req.grammar.result()
                self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                if req.grammar is INVALID_GRAMMAR_OBJ:
                    req.set_finish_with_abort(
                        f"Invalid grammar request: {req.grammar_key=}"
                    )
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # Timed-out requests start right after the agreed ready prefix
        # (num_ready_reqs_max). Starting from this rank's own num_ready_reqs
        # would, on a lagging rank, cancel a request just resolved above and
        # abort a different request than the other ranks do. Clamp to the queue
        # length in case the fastest rank consumed the whole queue.
        num_moved = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for i in range(num_ready_reqs_max, num_moved):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)

        self._extend_requests_to_queue(self.grammar_queue[:num_moved])
        self.grammar_queue = self.grammar_queue[num_moved:]

1842
1843
1844
1845
1846
1847
1848
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        if batch.next_batch_sampling_info:
            if batch.next_batch_sampling_info.grammars is not None:
                batch.next_batch_sampling_info.update_regex_vocab_mask()
                self.current_stream.synchronize()
            batch.next_batch_sampling_info.sampling_info_done.set()

1849
1850
1851
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. It gives up when a batch is
        in flight (``self.cur_batch`` is not None) and ``self.forward_ct`` has not
        advanced for more than ``watchdog_timeout`` seconds. It then:
        - logs diagnostics (unless request logging is disabled);
        - dumps the scheduler stacks;
        - waits 5 s and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            # Snapshot cur_batch once: other code paths (e.g. flush_cache) reset
            # it to None, and the diagnostics below dereference it. Re-reading
            # it could crash this thread before it sends SIGQUIT.
            stuck_batch = self.cur_batch
            if stuck_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: `watchdog_timeout // 2` is 0 for timeouts below 2 s,
            # which turned this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            # The labels are spelled out so the log text keeps the
            # `self.cur_batch...` names while reading the snapshot.
            logger.error(
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1883
1884
1885
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request; reply with whether ``flush_cache`` succeeded."""
        return FlushCacheReqOutput(success=self.flush_cache())
1886

1887
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when no request is queued or running. That
        includes every pipeline micro-batch in ``running_mbs`` when
        ``pp_size > 1``. On success it:
        - clears the current/last batch references;
        - resets the tree cache and the grammar backend (if any);
        - clears the request and token pools, plus the draft worker's pools
          under speculative decoding;
        - zeroes the generation/speculative counters;
        - releases cached CUDA memory.

        Returns:
            True if the cache was flushed, False if requests were pending.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Use the module logger (as the rest of the scheduler does) rather
            # than the root `logging` module, so handlers/levels apply.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
    def get_load(self):
        """Estimate this scheduler's load, measured in tokens.

        The load is the sum of two parts:
        - KV tokens not available for reuse, i.e. ``max_total_num_tokens``
          minus the allocator's available size minus the tree cache's
          evictable size;
        - the prompt lengths of every request still pending. That is the
          waiting queue plus, depending on disaggregation mode, the prefill
          bootstrap queue or the decode prealloc queue.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        used_tokens = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )
        pending_prompts = [req.origin_input_ids for req in self.waiting_queue]

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            pending_prompts += [
                req.origin_input_ids
                for req in self.disagg_prefill_bootstrap_queue.queue
            ]
        elif mode == DisaggregationMode.DECODE:
            # Prealloc-queue entries wrap the actual request in `.req`.
            pending_prompts += [
                entry.req.origin_input_ids
                for entry in self.disagg_decode_prealloc_queue.queue
            ]

        return used_tokens + sum(len(ids) for ids in pending_prompts)

1945
1946
1947
1948
1949
1950
1951
1952
1953
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of scheduler state for introspection.

        The snapshot is a copy of ``global_server_args_dict`` extended with:
        - the last generation throughput;
        - the average speculative accept length, only when speculative
          decoding is enabled and the accept count is positive;
        - the per-step timings, when ``RECORD_STEP_TIME`` is set;
        - the current ``get_load()`` value.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }
        uses_speculation = not self.spec_algorithm.is_none()
        if uses_speculation and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict
        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
1958
1959
1960
1961
1962

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Only ``max_micro_batch_size``, ``speculative_accept_threshold_single``
        and ``speculative_accept_threshold_acc`` may be changed.
        ``max_micro_batch_size`` must also lie in
        ``[1, max_running_requests // pp_size]``.

        Every key is validated before anything is written, so a rejected key
        leaves ``global_server_args_dict`` untouched. On success, the
        cumulative speculative acceptance statistics are logged (when
        available) and reset.

        Returns:
            SetInternalStateReqOutput whose ``updated`` flag is True only if
            the new values were applied, together with the current global
            server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            if k == "max_micro_batch_size":
                max_allowed = self.max_running_requests // self.pp_size
                if v > max_allowed or v < 1:
                    logger.warning(
                        f"Updating {k} to {v} is rejected because it is out of the valid range [1, {max_allowed}]."
                    )
                    if_success = False
                    break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            # Reset so later averages only cover steps run with the new settings.
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Report whether the update was actually applied.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler method named ``recv_req.method`` with ``recv_req.parameters``.

        Any exception raised while looking up or calling the method is caught
        and logged, and is reported in the result rather than propagated.
        ``barrier()`` is then called unconditionally, so every rank reaches it
        whether or not its own call succeeded.

        Returns:
            RpcReqOutput(success, message), where ``message`` is the exception
            text on failure and "" on success.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        # NOTE(review): any attribute of the scheduler can be invoked by name
        # here; confirm the RPC channel only carries trusted requests.
        error: Optional[Exception] = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        success = error is None
        return RpcReqOutput(success, "" if success else str(error))

    def save_remote_model(self, params):
        """Save the model to the remote location given by ``params["url"]``.

        The work is delegated to this rank's model runner.
        """
        target_url = params["url"]
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save the model in shards.

        The target path, file-name pattern and maximum shard size are read
        from ``params``, and the work is delegated to this rank's model runner.
        """
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

2032
2033
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        - Waiting queue: matching requests are removed, and an ``AbortReq``
          for each one is sent back to the tokenizer.
        - Running batch (plus the current batch, when it is a different
          batch): matching unfinished requests are flagged with ``to_abort``
          rather than removed.
        - Grammar queue: matching requests have their grammar cancelled and
          are finished with an abort reason.
        """
        rid_prefix = recv_req.rid

        # Waiting queue. Pop from the highest index down so that the
        # remaining recorded indices stay valid.
        waiting_hits = [
            idx
            for idx, req in enumerate(self.waiting_queue)
            if req.rid.startswith(rid_prefix)
        ]
        for idx in waiting_hits[::-1]:
            req = self.waiting_queue.pop(idx)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Running requests. Also scan the current batch when it is a
        # separate batch from the running one.
        reqs = self.running_batch.reqs
        if self.cur_batch is not None and self.cur_batch is not self.running_batch:
            reqs = reqs + self.cur_batch.reqs
        for req in reqs:
            if not req.rid.startswith(rid_prefix) or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            # We must use to_abort because it is in a running batch
            req.to_abort = True

        # Grammar queue.
        for req in self.grammar_queue:
            if req.rid.startswith(rid_prefix):
                logger.debug(f"Abort grammar queue request. {req.rid=}")
                req.grammar.cancel()
                req.set_finish_with_abort("Aborted by AbortReq.")

2064
2065
2066
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine; this scheduler provides no implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2067
2068
2069
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """Reload the model weights from disk in place.

        The reload is delegated to the TP worker. On success, the cache is
        flushed and the flush is asserted to have succeeded. On failure, the
        worker's message is logged.
        """
        ok, msg = self.tp_worker.update_weights_from_disk(recv_req)
        if not ok:
            logger.error(msg)
        else:
            cache_flushed = self.flush_cache()
            assert cache_flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(ok, msg, 0)
2076

2077
2078
2079
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the group used for online model-parameter updates.

        The work is done by the TP worker, whose ``(success, message)`` result
        is wrapped in an ``InitWeightsUpdateGroupReqOutput``.
        """
        result = self.tp_worker.init_weights_update_group(recv_req)
        success, message = result
        return InitWeightsUpdateGroupReqOutput(success, message)
2081
2082

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        The update is delegated to the TP worker. On success, the cache is
        flushed and the flush is asserted to have succeeded. On failure, the
        worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's
            ``success`` flag and ``message``.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            cache_flushed = self.flush_cache()
            assert cache_flushed, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2094

2095
2096
2097
2098
2099
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from tensors.

        The update is delegated to the TP worker. The cache is flushed only
        when the update succeeded and ``recv_req.flush_cache`` is set. On
        failure, the worker's message is logged.
        """
        ok, msg = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not ok:
            logger.error(msg)
        elif recv_req.flush_cache:
            cache_flushed = self.flush_cache()
            assert cache_flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(ok, msg)
2106

2107
2108
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the parameter requested in ``recv_req`` via the TP worker."""
        weights = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(weights)
2110

2111
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release the memory tracked by the memory saver adapter.

        The model's buffers are snapshotted first, so that
        ``resume_memory_occupation`` can write them back. The adapter is then
        paused and the cache is flushed.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2121

2122
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        The restored buffers are the ones captured by
        ``release_memory_occupation``. The stash is deleted afterwards.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        # The stashed copy is no longer needed once written back.
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

2131
2132
2133
2134
2135
2136
2137
    def slow_down(self, recv_req: SlowDownReqInput):
        """Set ``forward_sleep_time`` from the request.

        A non-positive value is normalized to ``None``.
        """
        sleep_time = recv_req.forward_sleep_time
        disable = sleep_time is not None and sleep_time <= 0
        self.forward_sleep_time = None if disable else sleep_time
        return SlowDownReqOutput()

2138
    def profile(self, recv_req: ProfileReq):
        """Handle a start/stop profiling request.

        For ``START_PROFILE``, the settings are recorded via ``init_profile``:
        - In profile-by-stage mode, capture is started later by
          ``_profile_batch_predicate``, so only the ``init_profile`` result is
          returned.
        - Otherwise profiling starts immediately, unless ``init_profile``
          rejected the request (for example, because a profile is already in
          progress). In that case its failure result is returned unchanged.

        Any other request type stops the current profile.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_result = self.init_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
        )
        if recv_req.profile_by_stage or not init_result.success:
            return init_result
        # Not a per-stage capture, so no stage is passed.
        return self.start_profile()

2162
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
    ) -> ProfileReqOutput:
        """Record profiler settings for a later ``start_profile`` call.

        Nothing is started here. Defaults:
        - ``output_dir`` falls back to ``$SGLANG_TORCH_PROFILER_DIR``, else
          "/tmp".
        - ``activities`` falls back to ``["CPU", "GPU"]``.

        A truthy ``num_steps`` sets the stop target(s):
        - with ``profile_by_stage``, per-stage targets are set and the
          prefill/decode counters are reset to 0;
        - otherwise, the target is ``forward_ct + num_steps``.

        A falsy ``num_steps`` clears ``profiler_target_forward_ct``.

        Returns:
            ProfileReqOutput with ``success=False`` if a profile is already in
            progress (no state is changed in that case), otherwise
            ``success=True``.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage
        self.torch_profiler_output_dir = (
            os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
            if output_dir is None
            else output_dir
        )
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = ["CPU", "GPU"] if activities is None else activities

        if not num_steps:
            # No forward-count stop target.
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")

        self.profile_steps = num_steps
        if self.profile_by_stage:
            self.profiler_target_prefill_ct = num_steps
            self.profiler_target_decode_ct = num_steps
            self.profiler_prefill_ct = 0
            self.profiler_decode_ct = 0
        else:
            # The caller will be notified when reaching profiler_target_forward_ct
            self.profiler_target_forward_ct = self.forward_ct + num_steps
        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start every profiler requested in ``self.profiler_activities``.

        Settings come from a prior ``init_profile`` call. Supported activity
        names:
        - "CPU" / "GPU": torch.profiler;
        - "RPD": rocpd tracer, which takes precedence over the torch profiler;
        - "MEM": CUDA memory history recording;
        - "CUDA_PROFILER": ``cudaProfilerStart``.

        Every started profiler sets ``profile_in_progress`` so that
        ``stop_profile`` will tear it down.

        Args:
            stage: the forward mode being captured when profiling by stage.
                Only used in the log message here.
        """
        stage_str = f" for {stage.__str__()}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir}",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                # Rank 0 recreates the shared trace database from scratch.
                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                try:
                    schema.writeSchema(connection)
                    connection.commit()
                finally:
                    connection.close()
            # Wait until rank 0 has (re)created trace.rpd before any rank traces.
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            # Without this, stop_profile refuses to run and the CUDA profiler
            # is never stopped.
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2270

2271
2272
2273
2274
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop every active profiler and write its output.

        Output files go under ``self.torch_profiler_output_dir``; their names
        include the TP rank and, when ``stage`` is given, a stage suffix:
        - torch profiler: a Chrome trace;
        - RPD: converted to a Chrome trace by rank 0;
        - "MEM": a memory snapshot pickle.

        A "CUDA_PROFILER" activity calls ``cudaProfilerStop``.

        Args:
            stage: when profiling by stage, the forward mode whose capture is
                ending. Its name is appended to the output file names.

        Returns:
            ProfileReqOutput with ``success=False`` if no profile is in
            progress, otherwise ``success=True``.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        stage_suffix = f"-{stage.__str__()}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        # Guard against profiler_activities being None for every membership check.
        activities = self.profiler_activities or []

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time())
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            # All ranks must have flushed before rank 0 converts the trace.
            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            self.rpd_profile_path = None

        if "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        """Start/stop step-bounded profiling based on ``batch``.

        Presumably invoked once per scheduled batch (the caller is not visible
        here; confirm).

        In profile-by-stage mode, prefill (extend) and decode batches get
        separate traces. A stage's trace starts on its first batch, and is
        stopped once more than the target number of batches of that stage
        have been counted.

        Otherwise, the single trace is stopped once ``forward_ct`` reaches
        ``profiler_target_forward_ct`` (no stop when that target is falsy).

        Raises:
            RuntimeError: in profile-by-stage mode, for a batch that is
                neither prefill nor decode.
        """
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                # First prefill batch opens the prefill trace.
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                # NOTE(review): the prefill counter keeps growing after the
                # target, so a prefill batch arriving while the *decode* trace
                # is running will stop that trace here, labelled EXTEND.
                # Confirm whether this is intended.
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                # First decode batch: close any still-open prefill trace, then
                # open the decode trace.
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            else:
                raise RuntimeError("unsupported profile stage")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2359

2360
2361
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Forward a start/stop/dump command to the global expert-distribution recorder.

        Raises:
            ValueError: if ``recv_req`` is not one of the three known commands.
        """
        recorder_method_by_req = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for known_req, method_name in recorder_method_by_req:
            if recv_req == known_req:
                recorder = get_global_expert_distribution_recorder()
                getattr(recorder, method_name)()
                return ExpertDistributionReqOutput()
        raise ValueError("Unrecognized ExpertDistributionReq value")
2370

2371
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new ``Session`` under ``recv_req.session_id``.

        Returns:
            ``OpenSessionReqOutput(session_id, False)`` if the id is already
            in use or is None (a warning is logged in both cases).
            ``OpenSessionReqOutput(session_id, True)`` after the session has
            been created and stored.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        new_session = Session(recv_req.capacity_of_str_len, session_id)
        self.sessions[session_id] = new_session
        return OpenSessionReqOutput(session_id, True)
2385
2386
2387
2388
2389
2390
2391
2392
2393

    def close_session(self, recv_req: CloseSessionReqInput):
        """Drop the session named by ``recv_req.session_id``.

        Logs a warning if no such session exists.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

2394
2395
    def get_print_prefix(self):
        """Build the rank-identifying log prefix, e.g. " DP0 TP1 PP2".

        Each component appears only when applicable:
        - DP when ``attn_dp_rank`` is set;
        - TP when ``tp_size > 1``;
        - PP when ``pp_size > 1``.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2404
2405
2406
2407
2408
2409
2410
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2411

2412
2413
2414
2415
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    A request qualifies when it has a string ``rid`` starting with
    "HEALTH_CHECK". Requests without a ``rid``, or whose ``rid`` is None or
    not a string, are not health checks; they return False instead of
    raising AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2430
2431
2432
2433
2434
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Setup steps:
    - configure the process: title, fault handler, logging, optional CPU
      affinity, and the embedding cache;
    - build a ``Scheduler`` and send a "ready" message with its token limits
      through ``pipe_writer``;
    - run the event loop that matches the disaggregation mode, pipeline
      parallelism, and overlap settings.

    If creating the scheduler or running its event loop raises, the traceback
    is logged and SIGQUIT is sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolve it before building the prefix so that the process title and the
    # log prefix both show the effective DP rank.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB: 100 by default, overridable via SGLANG_VLM_CACHE_SIZE_MB.
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)