scheduler.py 93.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
37
from sglang.srt import two_batch_overlap
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
45
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
46
47
48
49
50
51
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
52
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
53
    ReqToMetadataIdxAllocator,
54
    TransferBackend,
55
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
56
)
57
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
58
59
60
61
62
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
63
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
64
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
65
66
67
68
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
69
70
from sglang.srt.managers.io_struct import (
    AbortReq,
71
    CloseSessionReqInput,
72
    ExpertDistributionReq,
73
    ExpertDistributionReqOutput,
74
75
    FlushCacheReqInput,
    FlushCacheReqOutput,
76
77
    GetInternalStateReq,
    GetInternalStateReqOutput,
78
79
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
80
    HealthCheckOutput,
81
82
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
83
84
    OpenSessionReqInput,
    OpenSessionReqOutput,
85
    ProfileReq,
86
87
    ProfileReqOutput,
    ProfileReqType,
88
89
90
91
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
92
93
    RpcReqInput,
    RpcReqOutput,
94
95
    SetInternalStateReq,
    SetInternalStateReqOutput,
96
97
    SlowDownReqInput,
    SlowDownReqOutput,
98
99
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
100
101
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
102
103
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
104
105
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
106
)
107
from sglang.srt.managers.mm_utils import init_embedding_cache
108
109
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
110
    MultimodalInputs,
111
112
    Req,
    ScheduleBatch,
113
    global_server_args_dict,
114
)
115
116
117
118
119
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
120
121
122
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
123
from sglang.srt.managers.session_controller import Session
124
from sglang.srt.managers.tp_worker import TpModelWorker
125
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
126
from sglang.srt.managers.utils import validate_input_length
127
from sglang.srt.mem_cache.chunk_cache import ChunkCache
128
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
129
from sglang.srt.mem_cache.radix_cache import RadixCache
130
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
131
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
132
from sglang.srt.reasoning_parser import ReasoningParser
133
from sglang.srt.server_args import PortArgs, ServerArgs
134
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
135
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
136
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
137
from sglang.srt.utils import (
138
    DeepEPMode,
139
    DynamicGradMode,
140
141
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
142
    disable_request_logging,
143
    get_bool_env_var,
144
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
145
    kill_itself_when_parent_died,
146
    point_to_point_pyobj,
147
    pyspy_dump_schedulers,
148
    set_gpu_proc_affinity,
149
150
151
    set_random_seed,
    suppress_other_loggers,
)
152
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
153
154
155

logger = logging.getLogger(__name__)

# Debugging aid: enables the "test retract decode" path (SGLANG_TEST_RETRACT).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Flag read from SGLANG_RECORD_STEP_TIME; its consumer is not visible in this chunk.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar timeout, overridable via SGLANG_GRAMMAR_TIMEOUT (default 300; unit presumably seconds).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))
160

161

162
163
@dataclass
class GenerationBatchResult:
    """Result of running one generation batch through the model worker.

    Field order defines the positional constructor and must not change.
    """

    # Model logits output; Optional, so it may be absent.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors for pipeline parallelism (typed as a plain tensor here).
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled next-token ids, if any.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start offsets within the extend segment.
    extend_logprob_start_len_per_req: List[int]
    # Batch id.
    bid: int
    # Whether the batch could run with a CUDA graph.
    can_run_cuda_graph: bool
171
172
173
174
175
176
177
178


@dataclass
class EmbeddingBatchResult:
    """Result of running one embedding batch: the embeddings plus the batch id."""

    # Embedding tensor produced for the batch.
    embeddings: torch.Tensor
    # Batch id.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
179
180
181
182
183
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
184
185
186
187
188
189
190
191
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, tokenizer, workers, caches and policy.

        Args:
            server_args: Server configuration; most scheduler settings are read from it.
            port_args: IPC endpoint names and the NCCL port used by the workers.
            gpu_id: GPU id forwarded to the TP worker (and the draft worker, if any).
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank (may be None); forwarded to the workers.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info (self.dp_size is already set above; the previous
        # duplicate assignment has been dropped).
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only attention-TP rank 0 on pipeline
        # stage 0 owns real sockets; every other rank gets no-op senders and no
        # receivers.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets model_config / is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled:
        # the first token id of the parser's think-end marker is stored on the tokenizer.
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        TpWorkerClass = TpModelWorkerClient if self.enable_overlap else TpModelWorker
        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (-1) disables it. None is
        # treated as "disabled" too instead of raising TypeError on the comparison,
        # matching init_memory_pool_and_cache, which already allows None.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Linear decay from init_new_token_ratio down to min_new_token_ratio over
        # default_new_token_ratio_decay_steps steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. parent_process is recorded before the thread
        # starts, so the background thread never runs while the attribute is
        # still missing.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

468
469
    def init_tokenizer(self):
        """Load the model config and the tokenizer (plus processor for multimodal models).

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer`` and
        ``self.processor``:

        * ``skip_tokenizer_init``: tokenizer and processor are both ``None``.
        * multimodal model: the processor is loaded and the tokenizer is taken from it.
        * otherwise: a plain tokenizer is loaded and the processor is ``None``.
          Previously ``self.processor`` was left undefined on this path.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only model: no processor, but keep the attribute defined so
            # readers never hit AttributeError.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix/tree cache.

        Cache selection:

        * ``ChunkCache`` when ``chunked_prefill_size`` is not None and the radix
          cache is disabled;
        * ``HiRadixCache`` when the hierarchical cache is enabled;
        * ``RadixCache`` otherwise (still honoring ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise draft tokens + eagle_topk * num_steps.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        req_pool = self.req_to_token_pool
        kv_allocator = self.token_to_kv_pool_allocator

        use_chunk_cache = (
            args.chunked_prefill_size is not None and args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens
                + args.speculative_eagle_topk * args.speculative_num_steps
            )
541
542
543

    def init_metrics(self):
        """Reset the scheduler's throughput/speculative counters and stats objects.

        A ``SchedulerMetricsCollector`` labelled with the served model name and
        engine type "unified" is created only when ``self.enable_metrics`` is set.
        """
        # Throughput trackers (updated elsewhere in the scheduler).
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # Dict[batch size -> step time]
        self.step_time_dict = defaultdict(list)
        # Speculative-decoding counters (updated elsewhere).
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        metric_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=metric_labels)
Lianmin Zheng's avatar
Lianmin Zheng committed
559

560
561
562
563
    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV cache events are enabled.

        Args:
            kv_events_config: Config passed to ``EventPublisherFactory.create``;
                it is used only when ``self.enable_kv_cache_events`` is true.

        ``self.kv_event_publisher`` is always defined afterwards. It is ``None``
        when events are disabled; previously the attribute was simply missing in
        that case.
        """
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
        else:
            self.kv_event_publisher = None

Byron Hsu's avatar
Byron Hsu committed
564
    def init_disaggregation(self):
        """Set up queues and metadata buffers for prefill/decode disaggregation.

        ``self.transfer_backend`` is always created.

        * DECODE mode: creates the KV transfer queue and the pre-allocation queue.
          Both share a metadata buffer sized at twice the req-to-token pool size.
        * PREFILL mode: creates the bootstrap queue and an empty in-flight list.
          The metadata buffer is sized at twice ``max_running_requests``.

        Any other mode creates nothing beyond the transfer backend.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def draft_kv_pool():
            # The draft model's KV pool, or None when there is no draft worker.
            if self.draft_worker is None:
                return None
            return self.draft_worker.model_runner.token_to_kv_pool

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            metadata_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=draft_kv_pool(),
                req_to_metadata_buffer_idx_allocator=metadata_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
Byron Hsu's avatar
Byron Hsu committed
637

638
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the non-overlapped scheduler loop forever.

        Each iteration does the following:

        1. Receive and dispatch incoming requests.
        2. Pick the next batch.
        3. Run it and process its result synchronously.

        When no batch is available, the loop instead runs ``check_memory`` and
        resets ``new_token_ratio`` to its initial value.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                outcome = self.run_batch(next_batch)
                self.process_batch_result(next_batch, outcome)

            self.last_batch = next_batch
657

658
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU computation.

        Each launched batch is queued together with its (pending) result. The
        result of the previous iteration's batch is processed only after the
        current batch has been launched. It is handed the current batch's
        ``launch_done`` event, when there is a current batch.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            current = self.get_next_batch_to_run()
            self.cur_batch = current

            if current:
                current.launch_done = threading.Event()
                launched_result = self.run_batch(current)
                self.result_queue.append((current.copy(), launched_result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    warmup_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(warmup_batch, None, current.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if current else None
                )
                # NOTE: use the currently launched batch's launch_done event,
                # not the last batch's.
                done_event = current.launch_done if current else None
                self.process_batch_result(prev_batch, prev_result, done_event)
            elif current is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = current

702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        ``self.pp_size`` microbatch slots are kept in flight. For each slot
        ``mb_id`` the loop: receives and dispatches new requests, schedules and
        runs the slot's next batch, forwards outputs / requests / hidden-state
        proxy tensors to the next pipeline stage, and post-processes the batch
        of the following slot ``(mb_id + 1) % pp_size`` using the token ids
        received from the pipeline.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Per-slot `can_run_cuda_graph` flag, recorded when the slot's batch is
        # launched. Post-processing of slot `next_mb_id` must report the flag of
        # *that* batch, not of whichever batch happened to run most recently.
        can_run_cuda_graphs = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    can_run_cuda_graphs[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        # Flag of the batch being post-processed (recorded when it ran).
                        can_run_cuda_graph=can_run_cuda_graphs[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

    def recv_requests(self) -> List[Req]:
        """Collect incoming requests on the first attention-TP rank and share them.

        On pipeline stage 0, attention-TP rank 0 drains, without blocking, first
        the tokenizer socket and then the RPC socket. On later stages,
        attention-TP rank 0 receives the list forwarded by the previous stage
        through ``point_to_point_pyobj``. All other ranks start from ``None``.

        The list is then broadcast:
        - With DP attention, work requests (tokenized generate/embedding inputs)
          go over the attention-TP group. All other (control) requests go over
          the full TP group.
        - Otherwise the whole list is broadcast over the TP group when
          ``tp_size > 1``.
        """
        is_leader = self.attn_tp_rank == 0
        recv_reqs = None

        if is_leader and self.pp_rank == 0:
            recv_reqs = []
            # Drain the tokenizer channel first, then the RPC channel.
            for channel in (self.recv_from_tokenizer, self.recv_from_rpc):
                while True:
                    try:
                        incoming = channel.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(incoming)
        elif is_leader:
            dp_offset = self.attn_dp_rank * self.attn_tp_size
            recv_reqs = point_to_point_pyobj(
                [],
                self.pp_rank * self.tp_size + dp_offset,
                self.world_group.cpu_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                self.pp_rank * self.tp_size + dp_offset,
            )

        if not self.server_args.enable_dp_attention:
            if self.tp_size == 1:
                return recv_reqs
            return broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        work_reqs = None
        control_reqs = None
        if is_leader:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs, control_reqs = [], []
            for item in recv_reqs:
                bucket = work_reqs if isinstance(item, work_types) else control_reqs
                bucket.append(item)

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any immediate reply.

        Health-check generate requests are dropped, and only counted in
        ``self.return_health_check_ct``, while there is a chunked request or the
        running batch is non-empty.

        Every other request goes to ``self._request_dispatcher``. A non-``None``
        reply is sent back over the RPC socket when it is an ``RpcReqOutput``
        (and that socket exists); otherwise it goes to the tokenizer.
        """
        for item in recv_reqs:
            if is_health_check_generate_req(item):
                # Ignore health checks while real work is in progress.
                if self.chunked_req is not None or not self.running_batch.is_empty():
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(item)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        A fresh ``Req`` is built unless the request names a known session; in
        that case the session creates it.

        Requests that fail a check are mostly enqueued anyway so the client gets
        an error response. The checks that set ``FINISH_ABORT`` and enqueue are:
        - unknown session id;
        - session-level abort;
        - multimodal prompt too long after padding;
        - ``logprob_start_len`` out of range.
        A ``validate_input_length`` error is also enqueued, after the prompt is
        reset to ``[0]``.

        The exception is a disaggregated request without a bootstrap room. It is
        aborted and streamed out immediately instead of being enqueued.

        Requests with a grammar constraint whose compiled grammar is not cached
        yet go to ``self.grammar_queue``. Everything else goes through
        ``self._add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_message = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_message)
                    prepare_abort(req, error_message)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so that prompt + new tokens stays within max_req_len - 1;
        # an unset max_new_tokens is treated as effectively unbounded (1 << 30).
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The guard above guarantees structural_tag is not None here. Using a
                # plain `else` (instead of a truthiness test) keeps `key` always bound.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Remember the key so the grammar can be looked up once it is ready.
                req.grammar_key = key
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and enqueue ``req`` for this server's role.

        Prefill servers use the bootstrap queue, decode servers use the
        pre-allocation queue, and any other mode uses the local waiting queue.
        """
        req.queue_time_start = time.perf_counter()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            enqueue = self.disagg_prefill_bootstrap_queue.add
        elif mode == DisaggregationMode.DECODE:
            enqueue = self.disagg_decode_prealloc_queue.add
        else:
            enqueue = self.waiting_queue.append
        enqueue(req)

    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Append ``reqs`` to the queue that matches this server's disaggregation role.

        Unlike ``_add_request_to_queue``, this does not modify ``queue_time_start``.
        """
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(reqs)
            return
        if mode == DisaggregationMode.DECODE:
            # Decode servers park new requests in the pending pre-allocation queue.
            self.disagg_decode_prealloc_queue.extend(reqs)
            return
        self.waiting_queue.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Two kinds of request are still enqueued so the client receives an error
        response:
        - a multimodal prompt that is too long after padding (marked
          ``FINISH_ABORT``);
        - a prompt that fails ``validate_input_length``.

        Every request, rejected or not, is routed through
        ``self._add_request_to_queue``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper, which also stamps queue_time_start, so this
                # rejection lands in the same queue as every other request on this
                # server (matching the other rejection paths).
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a newly scheduled prefill batch.

        Also refreshes the input-throughput counters. When metrics are enabled,
        it pushes the following stats to the metrics collector:
        - running requests;
        - token usage;
        - queue length;
        - cache hit rate;
        - average queue latency.

        KV events are published on every call.
        """
        gap_latency = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        free_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens
        usage = num_used / self.max_total_num_tokens
        num_new_seq = len(can_run_list)

        pieces = [
            "Prefill batch. ",
            f"#new-seq: {num_new_seq}, ",
            f"#new-token: {adder.log_input_tokens}, ",
            f"#cached-token: {adder.log_hit_tokens}, ",
            f"token usage: {usage:.2f}, ",
            f"#running-req: {running_bs}, ",
        ]
        if self.disaggregation_mode != DisaggregationMode.PREFILL:
            pieces.append(f"#queue-req: {len(self.waiting_queue)}")
        else:
            pieces.extend(
                [
                    f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, ",
                    f"#queue-req: {len(self.waiting_queue)}, ",
                    f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} ",
                ]
            )
        logger.info("".join(pieces))

        if self.enable_metrics:
            hit_rate = adder.log_hit_tokens / (
                adder.log_input_tokens + adder.log_hit_tokens
            )
            stats = self.stats
            stats.num_running_reqs = running_bs
            stats.num_used_tokens = num_used
            stats.token_usage = round(usage, 2)
            stats.num_queue_reqs = len(self.waiting_queue)
            stats.cache_hit_rate = hit_rate

            waited = sum(r.queue_time_end - r.queue_time_start for r in can_run_list)
            stats.avg_request_queue_latency = waited / num_new_seq

            self.metrics_collector.log_stats(stats)
        self._publish_kv_events()

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a one-line summary of decode progress and update decode metrics.

        Args:
            can_run_cuda_graph: Whether the reported decode step used CUDA graph;
                it is only echoed in the log line.
            running_batch: Batch to report on. Falls back to ``self.running_batch``
                when omitted (or falsy).

        Resets the generated-token and speculative-decoding counters for the
        next interval and publishes KV events on every call.
        """
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard against an interval with no speculative forward passes, which
            # would otherwise raise ZeroDivisionError.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def check_memory(self):
        """Verify that no KV-cache tokens or request slots are leaked while idle.

        Raises:
            ValueError: in either of these cases:
                - free + evictable tokens do not equal ``max_total_num_tokens``
                  (minus the protected size when hierarchical cache is enabled);
                - not every ``req_to_token_pool`` slot is free.

        When metrics are enabled on attention-TP rank 0, idle-time stats are also
        pushed once more than 30 s have passed since
        ``metrics_collector.last_log_time``. KV events are published whenever no
        leak is found.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        # With hierarchical cache, protected tokens are legitimately not available.
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the header from the first field; the
                # adjacent literals are concatenated.
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch to run next, preferring a new prefill over decode.

        First, if the previous batch was an extend (prefill) batch, it is filtered
        (dropping the in-progress chunked request(s)) and merged into
        ``self.running_batch``. Then a new prefill batch from
        ``get_new_batch_prefill`` is returned if there is one; otherwise the
        running batch after ``update_running_batch`` (or ``None`` when nothing
        remains). With DP attention or SP layernorm enabled, the choice is passed
        through ``prepare_dp_attn_batch`` before being returned.

        Returns:
            The batch to run, or ``None`` if there is nothing to run.
        """
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks an outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                # Requests were dropped, so the running batch may have room again.
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret

    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be added to a batch of ``running_bs``.

        The budget is ``max_micro_batch_size`` minus ``running_bs``. With
        pipeline parallelism (``pp_size > 1``) it is additionally capped by
        the free slots in ``req_to_token_pool``. The result may be zero or
        negative.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1348
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Requests from ``self.waiting_queue`` go through a ``PrefillAdder``,
        after the pending chunked request if there is one. Admission stops at
        the first limit that is hit:
        - the per-batch request budget;
        - the LoRA adapter limit;
        - request-pool capacity in prefill-disaggregation mode;
        - any non-CONTINUE result from ``adder.add_one_req``.

        Admitted requests are removed from the waiting queue. When mixed
        chunked prefill is enabled, the running decode batch may be mixed
        into the new batch.

        Side effects:
        - may set ``self.running_batch.batch_is_full``;
        - updates ``self.chunked_req``;
        - may replace ``self.running_batch`` with an empty batch (mixed chunk);
        - logs prefill stats on attention-TP rank 0.

        Returns:
            The new prefill batch, or ``None`` if nothing can be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc queue and the transfer queue can
                # also hold request slots, so cap admission by the pool's actual
                # available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Rebuilding it inside the
        # comprehension made this O(len(waiting_queue) * len(can_run_list)).
        can_run_set = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_set]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1502
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Steps:
        1. Filter the batch with ``filter_batch``.
        2. If decode memory is insufficient (or ``TEST_RETRACT`` is set and
           the batch has more than 10 requests), retract requests back to the
           queue and adopt the new token ratio returned by ``retract_decode``.
           Otherwise decay ``self.new_token_ratio`` toward
           ``self.min_new_token_ratio``.
        3. Call ``prepare_for_decode``.

        Returns:
            ``batch`` itself, mutated in place (it may be empty).
        """
        size_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if must_retract:
            previous_ratio = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted)
        else:
            # Enough memory: decay the ratio, never going below the floor.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )

        if batch.batch_size() < size_at_entry:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1538

1539
1540
1541
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and wrap the outputs.

        Before the forward pass:
        - increments ``self.forward_ct``;
        - stops the profiler once ``profiler_target_forward_ct`` is reached;
        - optionally sleeps ``forward_sleep_time`` seconds.

        It then dispatches to:
        - the TP worker for plain generation or embedding;
        - the draft worker for speculative decoding.

        Returns:
            ``GenerationBatchResult`` for generation models, otherwise
            ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.send_to_tokenizer.send_pyobj(self.stop_profile())

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        is_last_rank = self.pp_group.is_last_rank

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            first_out, second_out, can_run_cuda_graph = (
                self.tp_worker.forward_batch_generation(worker_batch)
            )
            if is_last_rank:
                logits_output, next_token_ids = first_out, second_out
            else:
                # Non-last PP ranks get proxy tensors back instead of logits.
                pp_hidden_states_proxy_tensors = first_out
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        if is_last_rank:
            batch.output_ids = next_token_ids

        # Snapshot these per-request values now: the overlap scheduler may
        # modify them before output processing reads them.
        extend_input_len_per_req = None
        extend_logprob_start_len_per_req = None
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]

        return GenerationBatchResult(
            logits_output=logits_output if is_last_rank else None,
            pp_hidden_states_proxy_tensors=(
                pp_hidden_states_proxy_tensors if not is_last_rank else None
            ),
            next_token_ids=next_token_ids if is_last_rank else None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1618

1619
1620
1621
1622
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Consume the output of ``run_batch`` according to the batch's forward mode.

        Handling by mode:
        - decode and extend batches go to their dedicated handlers;
        - idle batches (only when overlap is enabled) resolve the last worker
          result, then mark the next batch's sampling info done;
        - dummy-first batches only mark the next batch's sampling info done.

        Afterwards, if health checks are pending, one ``HealthCheckOutput``
        is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Answer pending health checks from here so a long-context prefill cannot
        # block them. Caveat: this path does not check the detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1643
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across DP-attention ranks.

        This is a thin wrapper: it feeds this scheduler's sizes, process group,
        idle-batch factory and server options into the static
        ``prepare_dp_attn_batch_raw`` and returns that method's result unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """Exchange per-rank batch metadata across DP-attention workers.

        Each rank packs one int64 row and all-gathers it over ``tp_cpu_group``
        into a ``(dp_size, attn_tp_size, row_width)`` tensor. The row holds:
        - token count;
        - CUDA-graph eligibility;
        - logprob token count;
        - an is-extend flag;
        - the fields from ``TboDPAttentionPreparer.prepare_all_gather``.

        Only entry ``[:, 0, ...]`` (the first attention-TP rank of each DP
        group) is read back.

        If this rank has no batch but some DP rank has tokens, an idle batch
        is created via ``get_idle_batch``. The batch then receives:
        - ``global_num_tokens`` and ``global_num_tokens_for_logprob`` (local
          values only when ``moe_dense_tp_size == 1`` and
          ``enable_dp_lm_head``);
        - ``tbo_split_seq_index`` and ``global_forward_mode``;
        - ``can_run_dp_cuda_graph``, unless CUDA graphs are disabled.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` may still be
            ``None``. ``any_extend`` is True if any DP rank reported an
            extend batch.
        """
        # Token counts contributed by this rank.
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Columns 0-3 are fixed; columns 4+ come from the TBO preparer.
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        # Size the gather buffer from the packed row rather than a hard-coded
        # width, so it cannot drift out of sync with the fields above.
        row_width = local_info.numel()
        global_info = torch.empty(
            (dp_size, attn_tp_size, row_width),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        global_num_tokens = global_info[:, 0, 0].tolist()
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1765
1766
1767
1768
1769

    def get_idle_batch(self):
        """Build an empty batch prepared for idle mode.

        ``prepare_dp_attn_batch`` uses this as its ``get_idle_batch`` factory.
        It is called when this rank has no batch but some DP rank reports
        tokens.
        """
        no_reqs = []
        idle = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        idle.prepare_for_idle()
        return idle

1780
1781
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Scanning and abort rules:
        - ``grammar_queue`` is scanned in order. Each grammar future is
          resolved with a 0.03 s timeout, and the scan stops at the first one
          that is not ready.
        - A request whose grammar has been polled more than
          ``GRAMMAR_TIMEOUT / 0.03`` times is aborted with ``FINISH_ABORT``.

        Multi-rank behavior (when the TP group, or the attention-TP group
        under DP attention, has more than one rank):
        - the ready/abort counts are max-reduced across ranks so that every
          rank moves the same prefix of the queue;
        - ranks that are behind block on the grammars they have not resolved
          yet.

        Every moved request (ready or aborted) is handed to
        ``_extend_requests_to_queue`` and removed from ``grammar_queue``.
        """
        num_ready_reqs = 0
        num_abort_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                # This counts polls of the line above, not the wall time since
                # the request entered the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_abort_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()

            # Block on grammars that other ranks already consider ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
        else:
            num_ready_reqs_max = num_ready_reqs
            num_abort_reqs_max = num_abort_reqs

        # Abort timed-out requests for every TP size. A single-rank scheduler
        # must abort too; otherwise the timed-out request is re-polled forever
        # and blocks everything behind it. The aborted slot comes right after
        # the agreed ready prefix, so it is still an unresolved future on every
        # rank and all ranks abort the same request. Never index past the end
        # of the queue.
        num_abort_reqs_max = max(
            0, min(num_abort_reqs_max, len(self.grammar_queue) - num_ready_reqs_max)
        )
        for i in range(num_ready_reqs_max, num_ready_reqs_max + num_abort_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            req.grammar = None
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
        num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1832
1833
1834
1835
1836
1837
1838
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Signal that the next batch's sampling info is ready.

        If grammars are attached, the regex vocab mask is updated and the
        current stream is synchronized before ``sampling_info_done`` is set.
        This is a no-op when ``batch.next_batch_sampling_info`` is falsy.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

1839
1840
1841
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        About every ``watchdog_timeout / 2`` seconds it checks whether
        ``self.forward_ct`` has advanced while ``self.cur_batch`` is set. If
        there is no progress for more than ``watchdog_timeout`` seconds, it:
        - logs batch/memory diagnostics (unless request logging is disabled);
        - calls ``pyspy_dump_schedulers``;
        - waits 5 s so the parent can print the error;
        - sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is None:
                # Nothing is running, so there is no progress to wait for. Keep
                # the reference time fresh. Otherwise, after a long idle period,
                # a batch published before run_batch bumps forward_ct could
                # immediately look like a timeout.
                self.watchdog_last_time = current
            elif self.watchdog_last_forward_ct == self.forward_ct:
                if current > self.watchdog_last_time + self.watchdog_timeout:
                    break
            else:
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
            # True division: floor division yields 0 for timeouts below 2 s,
            # which would turn this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1873
1874
1875
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush request by delegating to ``flush_cache``.

        Returns:
            ``FlushCacheReqOutput`` whose ``success`` is ``flush_cache()``'s result.
        """
        return FlushCacheReqOutput(success=self.flush_cache())
1876

1877
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when no request is waiting and the running
        batch is empty (with pipeline parallelism, every micro-batch in
        ``running_mbs`` must also be empty). In that case it:
        - clears the current and last batch;
        - resets the tree cache and grammar backend;
        - clears the request/token pools (including the draft worker's under
          speculative decoding);
        - zeroes the generation and speculative counters;
        - calls ``torch.cuda.empty_cache``.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger, like the success path above, instead of
            # the root logger.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

Liangsheng Yin's avatar
Liangsheng Yin committed
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
    def get_load(self):
        """Estimate this scheduler's load in tokens.

        The load is the sum of:
        - KV-pool tokens in use: total minus available minus evictable;
        - input lengths of every request in the waiting queue;
        - in prefill-disaggregation mode, input lengths in the bootstrap queue;
        - in decode-disaggregation mode, input lengths in the prealloc queue.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        used_tokens = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )
        queued_tokens = sum(len(r.origin_input_ids) for r in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            bootstrap_reqs = self.disagg_prefill_bootstrap_queue.queue
            queued_tokens += sum(len(r.origin_input_ids) for r in bootstrap_reqs)
        elif mode == DisaggregationMode.DECODE:
            prealloc_entries = self.disagg_decode_prealloc_queue.queue
            queued_tokens += sum(
                len(entry.req.origin_input_ids) for entry in prealloc_entries
            )

        return used_tokens + queued_tokens

1935
1936
1937
1938
1939
1940
1941
1942
1943
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of scheduler state.

        The snapshot starts from a copy of ``global_server_args_dict`` and adds:
        - ``last_gen_throughput``;
        - ``avg_spec_accept_length``, when speculative decoding has recorded
          at least one accept;
        - ``step_time_dict``, when ``RECORD_STEP_TIME`` is set;
        - ``load`` from ``get_load``.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
1948
1949
1950
1951
1952

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update whitelisted entries of ``global_server_args_dict`` at runtime.

        Rules:
        - Only ``max_micro_batch_size``, ``speculative_accept_threshold_single``
          and ``speculative_accept_threshold_acc`` may be changed.
        - ``max_micro_batch_size`` must lie in
          ``[1, max_running_requests // pp_size]``.
        - Validation is all-or-nothing: the first rejected key stops the loop
          and nothing is applied.
        - On success, the average speculative accept length (if any) is
          logged and the accept counters are reset.

        Returns:
            ``SetInternalStateReqOutput``. ``updated`` is True only when the
            update was applied; ``server_args`` is the current global dict.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Report whether the update was actually applied so callers can
            # tell a rejected request from a successful one.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler method named ``recv_req.method``.

        The method is looked up on ``self`` and called with
        ``recv_req.parameters`` as its single argument. Any exception from the
        lookup or the call is logged (with traceback) and reported back rather
        than propagated. ``torch.distributed.barrier()`` is then entered
        unconditionally, whether or not the local call succeeded.

        Returns:
            RpcReqOutput(success, message). ``message`` is ``""`` on success
            and ``str(exception)`` on failure.
        """
        # NOTE(review): any attribute of self can be invoked by name here;
        # confirm that only trusted components can send RpcReqInput.
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        # Named `error` rather than `exec` to avoid shadowing the builtin.
        error: Optional[Exception] = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.exception(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Forward ``params["url"]`` to the model runner's ``save_remote_model``."""
        target_url = params["url"]
        runner = self.tp_worker.worker.model_runner
        runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Forward path/pattern/max_size from *params* to the model runner.

        Only those three keys are passed, as keyword arguments, to
        ``model_runner.save_sharded_model``. Other keys in *params* are ignored.
        """
        save_sharded = self.tp_worker.worker.model_runner.save_sharded_model
        save_sharded(
            **{key: params[key] for key in ("path", "pattern", "max_size")}
        )

2022
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests in the waiting queue are removed from it, and an
        ``AbortReq`` is sent to the tokenizer for each one. Matching requests
        that are not yet finished are looked up in ``running_batch``, plus
        ``cur_batch`` when that is a different batch. They are only flagged
        with ``to_abort = True``; this method does not remove them from
        their batch.
        """
        # TODO(lmzheng): abort the requests in the grammar queue.
        rid_prefix = recv_req.rid

        # Waiting queue: collect the matching positions first, then pop from
        # the highest position down so the remaining positions stay valid.
        doomed = [
            pos
            for pos, queued in enumerate(self.waiting_queue)
            if queued.rid.startswith(rid_prefix)
        ]
        for pos in reversed(doomed):
            req = self.waiting_queue.pop(pos)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Running requests: include cur_batch only when it is not the running
        # batch itself, so the same requests are not visited twice.
        cur_batch = self.cur_batch
        if cur_batch is None or cur_batch is self.running_batch:
            candidates = self.running_batch.reqs
        else:
            candidates = self.running_batch.reqs + cur_batch.reqs

        for req in candidates:
            if req.rid.startswith(rid_prefix) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2047

2048
2049
2050
    def _pause_engine(self) -> Tuple[List[Req], int]:
        raise NotImplementedError()

Chayenne's avatar
Chayenne committed
2051
2052
2053
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
2054
2055
2056
2057
2058
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
2059
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2060

2061
2062
2063
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Thin wrapper: passes the request to the TP worker and wraps the
        worker's ``(success, message)`` result in an
        InitWeightsUpdateGroupReqOutput.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
2065
2066

    def update_weights_from_distributed(
2067
2068
2069
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> Tuple[bool, str]:
2070
2071
2072
2073
2074
2075
2076
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
2077
        return UpdateWeightsFromDistributedReqOutput(success, message)
2078

2079
2080
2081
2082
2083
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
2084
2085
2086
            if recv_req.flush_cache:
                flash_cache_success = self.flush_cache()
                assert flash_cache_success, "Cache flush failed after updating weights"
2087
2088
        else:
            logger.error(message)
2089
        return UpdateWeightsFromTensorReqOutput(success, message)
2090

2091
2092
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in the output type."""
        weights = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(weights)
2094

2095
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release memory occupation through the memory saver adapter.

        The model's buffers are snapshotted into
        ``self.stashed_model_static_state`` first; ``resume_memory_occupation``
        restores them later. Then the adapter is paused and the cache is
        flushed. The flush result is not checked here.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)

        adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2105

2106
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        The buffers come from the snapshot stored by
        ``release_memory_occupation``. The snapshot attribute is deleted after
        it has been copied back into the model.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()

        _import_static_state(
            self.tp_worker.worker.model_runner.model,
            self.stashed_model_static_state,
        )
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

2115
2116
2117
2118
2119
2120
2121
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store the requested ``forward_sleep_time`` on the scheduler.

        A missing value or a value ``<= 0`` clears the setting (stores None).
        Presumably the value is consumed elsewhere around forward passes;
        not visible here.
        """
        requested = recv_req.forward_sleep_time
        non_positive = requested is not None and requested <= 0
        self.forward_sleep_time = None if non_positive else requested
        return SlowDownReqOutput()

2122
    def profile(self, recv_req: ProfileReq):
        """Send a profile request to :meth:`start_profile` or :meth:`stop_profile`.

        Any request type other than ``START_PROFILE`` is treated as a stop.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        return self.start_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_id,
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> "Optional[ProfileReqOutput]":
        """Start a profiling session.

        Args:
            output_dir: Directory for output files. Defaults to
                ``$SGLANG_TORCH_PROFILER_DIR`` or ``/tmp``.
            num_steps: If truthy, profiling targets ``forward_ct + num_steps``
                and ``None`` is returned; the caller is notified when the
                target is reached. Otherwise profiling runs until
                :meth:`stop_profile` and a success output is returned now.
            activities: Any of ``"CPU"``/``"GPU"`` (torch profiler),
                ``"MEM"`` (CUDA memory history) and ``"CUDA_PROFILER"``
                (``cudaProfilerStart``). Defaults to ``["CPU", "GPU"]``.
            with_stack: Passed to the torch profiler. Defaults to True.
            record_shapes: Passed to the torch profiler. Defaults to False.
            profile_id: Prefix for output file names. Defaults to the
                current timestamp.

        Returns:
            A failure ProfileReqOutput if profiling is already running, a
            success ProfileReqOutput when ``num_steps`` is falsy, else None.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # stop_profile builds file names by concatenating this id. A None
            # id would make it raise TypeError, so use a timestamp instead.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
2192
2193

    def stop_profile(self) -> "ProfileReqOutput":
        """Stop the active profiling session and write its outputs.

        Steps, in order:

        * If a torch profiler was started, stop it and export a gzipped
          Chrome trace.
        * If ``"MEM"`` was requested, dump the CUDA memory snapshot.
        * If ``"CUDA_PROFILER"`` was requested, stop the CUDA profiler.
        * Clear the profiling state.

        Returns:
            A failure ProfileReqOutput if no profiling is in progress, otherwise
            a success ProfileReqOutput.
        """
        if self.profiler_activities is None:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        logger.info("Stop profiling...")
        output_dir = self.torch_profiler_output_dir
        # start_profile accepts profile_id=None. Fall back to a timestamp so the
        # file names below can always be built. Otherwise a TypeError would
        # leave the torch profiler stopped but profiler_activities still set,
        # and start_profile would then refuse every new session.
        profile_id = (
            self.profiler_id if self.profiler_id is not None else str(time.time())
        )
        if self.torch_profiler is not None or "MEM" in self.profiler_activities:
            # Files are written into output_dir below; make sure it exists.
            os.makedirs(output_dir, exist_ok=True)

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    output_dir,
                    profile_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                output_dir,
                profile_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None
        self.profiler_target_forward_ct = None

        return ProfileReqOutput(success=True, message="Succeeded")
2231

2232
2233
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the global expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is not START_RECORD, STOP_RECORD or
                DUMP_RECORD.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            action = "start_record"
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            action = "stop_record"
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            action = "dump_record"
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")

        recorder = get_global_expert_distribution_recorder()
        getattr(recorder, action)()
        return ExpertDistributionReqOutput()
2242

2243
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new Session under ``recv_req.session_id``.

        Refused (with a warning) when the id is already open or is None.
        Returns OpenSessionReqOutput(session_id, opened).
        """
        session_id = recv_req.session_id

        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2257
2258
2259
2260
2261
2262
2263
2264
2265

    def close_session(self, recv_req: CloseSessionReqInput):
        # handle error
        session_id = recv_req.session_id
        if session_id not in self.sessions:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")
        else:
            del self.sessions[session_id]

2266
2267
    def get_print_prefix(self):
        """Return a rank tag such as ``" DP0 TP1 PP2"``.

        The DP part appears only when ``attn_dp_rank`` is set. The TP and PP
        parts appear only when their respective sizes are greater than 1.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2276
2277
2278
2279
2280
2281
2282
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2283

2284
2285
2286
2287
def is_health_check_generate_req(recv_req):
    """Return True if *recv_req* carries a health-check request id.

    A request counts as a health check when its ``rid`` attribute is a string
    starting with ``"HEALTH_CHECK"``. Objects without ``rid``, or whose ``rid``
    is not a string (e.g. None), are not health checks. The previous version
    raised AttributeError for a non-string ``rid``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2302
2303
2304
2305
2306
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Setup, in order:

    * Configure the process: parent-death handling, process title, fault
      handler, logging, optional CPU affinity and the embedding cache.
    * Build a ``Scheduler`` and send ``{"status": "ready", ...}`` with its
      token limits through *pipe_writer*.
    * Run the event loop chosen by the disaggregation mode, ``pp_size`` and
      the overlap setting.

    Exceptions raised while building the scheduler or running its loop are
    logged, and SIGQUIT is sent to the parent process.
    """
    # Rank prefix for the process title and logs. It uses dp_rank as passed
    # in, i.e. before the SGLANG_DP_RANK override below.
    prefix_parts = []
    if dp_rank is not None:
        prefix_parts.append(f" DP{dp_rank}")
    if server_args.tp_size > 1:
        prefix_parts.append(f" TP{tp_rank}")
    if server_args.pp_size > 1:
        prefix_parts.append(f" PP{pp_rank}")
    prefix = "".join(prefix_parts)

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    router_dp_rank = os.environ.get("SGLANG_DP_RANK")
    if dp_rank is None and router_dp_rank is not None:
        dp_rank = int(router_dp_rank)

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB, from SGLANG_VLM_CACHE_SIZE_MB (default 100).
    cache_size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", 100))
    init_embedding_cache(cache_size_mb * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        mode: DisaggregationMode = scheduler.disaggregation_mode
        if mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                run_loop = scheduler.event_loop_pp
            elif scheduler.enable_overlap:
                run_loop = scheduler.event_loop_overlap
            else:
                run_loop = scheduler.event_loop_normal
        elif mode == DisaggregationMode.PREFILL:
            run_loop = (
                scheduler.event_loop_overlap_disagg_prefill
                if scheduler.enable_overlap
                else scheduler.event_loop_normal_disagg_prefill
            )
        elif mode == DisaggregationMode.DECODE:
            run_loop = (
                scheduler.event_loop_overlap_disagg_decode
                if scheduler.enable_overlap
                else scheduler.event_loop_normal_disagg_decode
            )
        else:
            run_loop = None

        if run_loop is not None:
            run_loop()
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)