scheduler.py 91.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
39
40
41
42
43
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
44
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
45
46
47
48
49
50
51
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
52
    TransferBackend,
53
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
54
)
55
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
56
57
58
59
60
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
61
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
62
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
63
64
65
66
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
67
68
from sglang.srt.managers.io_struct import (
    AbortReq,
69
    CloseSessionReqInput,
70
    ExpertDistributionReq,
71
    ExpertDistributionReqOutput,
72
73
    FlushCacheReqInput,
    FlushCacheReqOutput,
74
75
    GetInternalStateReq,
    GetInternalStateReqOutput,
76
77
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
78
    HealthCheckOutput,
79
80
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
81
82
    OpenSessionReqInput,
    OpenSessionReqOutput,
83
    ProfileReq,
84
85
    ProfileReqOutput,
    ProfileReqType,
86
87
88
89
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
90
91
    RpcReqInput,
    RpcReqOutput,
92
93
    SetInternalStateReq,
    SetInternalStateReqOutput,
94
95
    SlowDownReqInput,
    SlowDownReqOutput,
96
97
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
98
99
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
100
101
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
102
103
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
104
)
105
from sglang.srt.managers.mm_utils import init_embedding_cache
106
107
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
108
    MultimodalInputs,
109
110
    Req,
    ScheduleBatch,
111
    global_server_args_dict,
112
)
113
114
115
116
117
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
118
119
120
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
121
from sglang.srt.managers.session_controller import Session
122
from sglang.srt.managers.tp_worker import TpModelWorker
123
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
124
from sglang.srt.managers.utils import validate_input_length
125
from sglang.srt.mem_cache.chunk_cache import ChunkCache
126
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
127
from sglang.srt.mem_cache.radix_cache import RadixCache
128
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
129
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
130
from sglang.srt.reasoning_parser import ReasoningParser
131
from sglang.srt.server_args import PortArgs, ServerArgs
132
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
133
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
134
from sglang.srt.utils import (
135
    DynamicGradMode,
136
137
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
138
    disable_request_logging,
139
    get_bool_env_var,
140
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
141
    kill_itself_when_parent_died,
142
    point_to_point_pyobj,
143
    pyspy_dump_schedulers,
144
    set_gpu_proc_affinity,
145
146
147
    set_random_seed,
    suppress_other_loggers,
)
148
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
149
150
151

logger = logging.getLogger(__name__)

# Diagnostic switches, each read once from the environment at import time.
# SGLANG_TEST_RETRACT: test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: boolean switch (consumers live outside this view).
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# SGLANG_GRAMMAR_TIMEOUT: float, default 300 (presumably seconds -- confirm at use site).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))


@dataclass
class GenerationBatchResult:
    """Result bundle returned after running one generation batch.

    Fields are positional in the order below; callers may construct it
    positionally, so the order must not change.
    """

    # Model logits output, or None.
    logits_output: Optional[LogitsProcessorOutput]
    # Tensors handed between pipeline stages, or None
    # (NOTE(review): annotated as torch.Tensor; confirm vs PPProxyTensors).
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Sampled token ids, or None.
    next_token_ids: Optional[List[int]]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start lengths for the extend phase.
    extend_logprob_start_len_per_req: List[int]
    # Batch id.
    bid: int
    # Whether this batch could run with CUDA graphs.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result bundle returned after running one embedding batch."""

    # Embeddings produced for the batch (shape not fixed here).
    embeddings: torch.Tensor
    # Batch id.
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
180
181
182
183
184
185
186
187
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Create the scheduler for one (tp_rank, pp_rank, dp_rank) worker.

        Sets up IPC sockets, the tokenizer, the model worker (plus an EAGLE
        draft worker when the speculative algorithm is EAGLE), the memory pool
        and prefix cache, scheduling-policy state, a watchdog thread, metrics,
        the request dispatcher and prefill/decode disaggregation state.

        Args:
            server_args: Server-wide configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: Device id handed to the model worker(s).
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or None.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info (self.dp_size is set once, above).
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only the first pipeline stage's
        # attention-TP rank 0 owns real sockets; every other rank gets no-op
        # senders and no receivers.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets model_config and is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled.
        # Only the first token id of the encoded end-of-thinking marker is kept.
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        # Default the micro-batch size to the running-request budget split
        # across pipeline stages, never below 1.
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (e.g. -1) disables it; None
        # is treated the same way, matching init_memory_pool_and_cache, which
        # already allows chunked_prefill_size to be None.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # new_token_ratio starts at init_new_token_ratio and decays linearly
        # toward min_new_token_ratio over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. The parent-process handle is captured before
        # the thread starts, in case watchdog_thread (defined elsewhere) uses it.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Always sets ``model_config``, ``is_generation``, ``tokenizer`` and
        ``processor``:
          * ``skip_tokenizer_init``: tokenizer and processor are both None;
          * multimodal model: a processor is loaded and the tokenizer is taken
            from it;
          * otherwise: a plain tokenizer is loaded and processor is None.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
            # Text-only models have no processor; define the attribute anyway so
            # it exists on every path, as in the other two branches.
            self.processor = None

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the model worker and build the prefix cache.

        Cache selection:
          * ChunkCache when chunked_prefill_size is not None and the radix
            cache is disabled;
          * HiRadixCache when hierarchical caching is enabled;
          * RadixCache otherwise (with ``disable=server_args.disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise
        speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_steps.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        req_pool = self.req_to_token_pool
        kv_allocator = self.token_to_kv_pool_allocator

        use_chunk_cache = (
            args.chunked_prefill_size is not None and args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens
                + args.speculative_eagle_topk * args.speculative_num_steps
            )

    def init_metrics(self):
        """Reset throughput and speculative-decoding counters; set up metrics export.

        A ``SchedulerMetricsCollector`` is created only when ``enable_metrics``
        is set.
        """
        # Last reported throughput values (updated elsewhere).
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # Dict[batch size -> step time]
        self.step_time_dict = defaultdict(list)

        # Speculative decoding acceptance counters.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        self.metrics_collector = SchedulerMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                "engine_type": "unified",
            },
        )

    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV cache events are enabled.

        Args:
            kv_events_config: Passed unchanged to ``EventPublisherFactory.create``.
        """
        if not self.enable_kv_cache_events:
            return
        self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)

    def init_disaggregation(self):
        """Set up prefill/decode disaggregation state for the configured mode.

        Always sets ``transfer_backend``. In DECODE mode it builds the transfer
        queue (requests polling for their KV cache), the pre-allocation queue and
        a pre-allocation token counter. In PREFILL mode it builds the bootstrap
        queue and the in-flight list. Other modes set up nothing further.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def _make_metadata_buffers(buffer_size):
            """Return (idx_allocator, aux_dtype, metadata_buffers) with buffer_size slots."""
            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            aux_dtype = torch.int32
            # A list of metadata buffers, each of shape (buffer_size, metadata_size).
            # Per the RDMA requirement, last dim * dtype.itemsize must reach
            # 64 bytes, so rows are padded to 16 int32 values (16 * 4 = 64 bytes).
            output_id_buffer = torch.zeros(
                (buffer_size, 16), dtype=aux_dtype, device="cpu"
            )
            metadata_buffers = [output_id_buffer]
            return req_to_metadata_buffer_idx_allocator, aux_dtype, metadata_buffers

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # Slots sized from the req_to_token pool; *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = _make_metadata_buffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation. It shares the
            # allocator and buffers with the transfer queue above.
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Slots sized from max running requests; *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = _make_metadata_buffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration receives and dispatches new requests, picks the next
        batch, then either runs it and processes its result synchronously or,
        when there is nothing to run, does the idle-time self-check and resets
        ``new_token_ratio``.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                output = self.run_batch(next_batch)
                self.process_batch_result(next_batch, output)

            self.last_batch = next_batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the next batch (queueing a copy of it together
        with its pending result), then processes the result of the batch
        launched on the previous iteration. Processing is given the current
        batch's ``launch_done`` event, or None when nothing was launched.
        """
        self.result_queue = deque()

        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            current = self.get_next_batch_to_run()
            self.cur_batch = current

            if current:
                current.launch_done = threading.Event()
                launched = self.run_batch(current)
                self.result_queue.append((current.copy(), launched))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap
                    # schedule; it is used to trigger the sampling_info_done event.
                    bootstrap = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(bootstrap, None, current.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if current else None
                )
                # NOTE: use the currently launched batch's launch_done event,
                # not the previous batch's.
                done_event = current.launch_done if current else None
                self.process_batch_result(prev_batch, prev_result, done_event)
            elif current is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = current

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The scheduler keeps ``pp_size`` micro-batch slots. For each slot
        ``mb_id`` it schedules and runs a batch. It then receives the sampled
        token ids for the batch previously launched in the *next* slot and
        post-processes that batch. The last pipeline rank sends its sampled
        token ids through ``pp_group``. Other ranks forward the token ids from
        the previous step, the received requests and their hidden-state proxy
        tensors to the next stage. This method never returns.
        """
        # Per-slot state, indexed by micro-batch id.
        mbs = [None] * self.pp_size  # most recently scheduled batch per slot
        last_mbs = [None] * self.pp_size  # last batch post-processed per slot
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size  # batch id of the batch launched per slot
        # can_run_cuda_graph of the batch launched per slot. The result built for
        # mbs[next_mb_id] must report that batch's own flag, not the flag of
        # whichever batch happened to run most recently.
        mb_can_run_cuda_graph = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap in this slot's scheduling state.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

807
808
    def recv_requests(self) -> List[Req]:
        """Receive requests at attn_tp_rank = 0 and broadcast them to the other TP ranks.

        On the first pipeline stage the leader drains, without blocking, first
        the tokenizer socket and then the RPC socket. On later stages it
        receives the request list forwarded by the previous stage. Non-leader
        ranks start with ``None`` and obtain the list from the broadcasts
        below. With DP attention, work requests (tokenized generate/embedding
        inputs) are broadcast within the attention-TP group and all other
        (control) requests within the full TP group. The result is work
        requests first, then control requests.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                    while True:
                        try:
                            incoming = sock.recv_pyobj(zmq.NOBLOCK)
                        except zmq.ZMQError:
                            break
                        recv_reqs.append(incoming)
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                own_rank = self.pp_rank * self.tp_size + dp_offset
                recv_reqs = point_to_point_pyobj(
                    [],
                    own_rank,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    own_rank,
                )

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    is_work = isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                    (work_reqs if is_work else control_reqs).append(req)
            else:
                work_reqs = control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
885
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any reply back to its sender.

        A health-check generate request is counted in ``return_health_check_ct``
        and dropped without dispatch while a chunked request is pending or the
        running batch is non-empty. An ``RpcReqOutput`` reply goes back over
        the RPC socket when one exists. Any other non-``None`` reply goes to the
        tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                # Skip the health check while the scheduler is already busy.
                if self.chunked_req is not None or not self.running_batch.is_empty():
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue
            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
901
902
903
904
905

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        How the request is built:
        - With no session, or an unknown session id, the ``Req`` is built
          directly from ``recv_req``. An unknown session id is then aborted.
        - A request naming an existing session is created by that session.

        How failures are handled:
        - A request that fails validation (multimodal expansion too long,
          prompt too long, bad ``logprob_start_len``) is marked finished.
          It is still enqueued through ``_add_request_to_queue``.
        - In disaggregated mode, a request without a bootstrap room is
          aborted and streamed back immediately instead.

        Requests with a grammar constraint whose grammar is not cached yet
        are parked in ``self.grammar_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_message = (
                        f"Invalid request: Disaggregated request received without "
                        f"bootstrap room id. {req.rid=}"
                    )
                    logger.error(error_message)
                    prepare_abort(req, error_message)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Replace the prompt with a placeholder so the aborted request is cheap.
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # A missing max_new_tokens is treated as effectively unbounded (1 << 30).
        # The result is then capped so prompt + output fits within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The guard above ensures structural_tag is not None here. A plain
                # `else` keeps `key` bound even for a falsy structural_tag.
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Grammar is still compiling; park the request until it is ready.
                req.grammar_key = key
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Timestamp ``req`` and hand it to the queue used by the current mode.

        Prefill-disaggregated schedulers use the bootstrap queue.
        Decode-disaggregated schedulers use the pre-allocation queue.
        All other modes append to ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Enqueue several requests at once without touching their timestamps.

        Decode-disaggregated schedulers extend the pre-allocation queue. Every
        other mode extends ``waiting_queue``. ``is_retracted`` is accepted but
        not used here.
        """
        target_queue = (
            self.disagg_decode_prealloc_queue
            if self.disaggregation_mode == DisaggregationMode.DECODE
            else self.waiting_queue
        )
        target_queue.extend(reqs)
1078
1079
1080

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` for an embedding request and enqueue it.

        Multimodal inputs are first expanded into padded token ids. A prompt
        that becomes too long is marked with ``FINISH_ABORT`` and appended
        straight to ``waiting_queue``. Otherwise the prompt length is validated
        and the request always goes through ``_add_request_to_queue``. Only a
        request that passes validation gets ``logprob_start_len`` set (to the
        last input token).
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        mm_payload = recv_req.image_inputs
        if mm_payload is not None:
            image_inputs = MultimodalInputs.from_dict(mm_payload)
            # Widen each image placeholder into dummy tokens that will receive the image embeddings.
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            too_long = len(req.origin_input_ids) >= self.max_req_input_len
            if too_long:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                req.queue_time_start = time.perf_counter()
                self.waiting_queue.append(req)
                return

        length_error = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if not length_error:
            req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1129

1130
1131
1132
1133
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a newly scheduled prefill batch and update metrics.

        Args:
            adder: The ``PrefillAdder`` that built the batch. Its
                ``log_input_tokens`` / ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of requests in the running batch.
        """
        # One timestamp for both the measured gap and the new tic. Otherwise the
        # interval between two perf_counter() calls is dropped from every window.
        now = time.perf_counter()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        msg = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            msg += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            msg += f"#queue-req: {len(self.waiting_queue)}, "
            msg += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
        else:
            msg += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(msg)

        if self.enable_metrics:
            # Guard the ratios so metrics logging never raises ZeroDivisionError.
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1182

1183
1184
1185
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a one-line summary of decode progress and update metrics.

        Args:
            can_run_cuda_graph: Reported verbatim in the log line.
            running_batch: Batch whose request count is reported. Falls back to
                ``self.running_batch`` when omitted (or falsy).
        """
        batch = running_batch or self.running_batch

        # One timestamp for both the measured gap and the new tic.
        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            # Gap divided by decode_log_interval; presumably this method runs once
            # per decode_log_interval steps -- confirm against the caller.
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard against a window with no speculative forward passes.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0
            )
            # Fold this window's counters into the cumulative totals, then reset them.
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1242

Lianmin Zheng's avatar
Lianmin Zheng committed
1243
1244
    def check_memory(self):
        """Verify that the KV-cache and request pools are fully released.

        The event loops call this when no batch is running. At that point:
        - every KV-cache token must be either free or evictable, except the
          tree cache's protected tokens when hierarchical caching is enabled;
        - every ``req_to_token_pool`` slot must be free.

        With metrics enabled on attention-TP rank 0, it also logs idle stats
        at most every 30 seconds.

        Raises:
            ValueError: If either pool has fewer free entries than expected.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            # Separator after "detected!" so the fields do not run into the header.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            # Reuse the size computed above rather than querying the pools again.
            num_used = self.max_total_num_tokens - available_size
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1290

1291
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        Steps:
        1. If the previous batch was an extend (prefill) batch, filter it
           (excluding any chunked request) and merge the rest into
           ``self.running_batch``.
        2. Prefer a new prefill batch from ``get_new_batch_prefill``.
        3. Otherwise, update and return the running (decode) batch, or
           ``None`` when nothing is runnable.

        With DP attention or SP layernorm, the choice is finally passed
        through ``prepare_dp_attn_batch``.
        """
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # With pipeline parallelism, after the last chunk the current
                # micro-batch may still track an outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                # Some requests left the batch, so there may be room again.
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
1340

1341
1342
1343
1344
1345
1346
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may join a micro-batch of size ``running_bs``.

        The budget is ``global_server_args_dict["max_micro_batch_size"]`` minus
        ``running_bs``. With pipeline parallelism (``pp_size > 1``) it is further
        capped by ``req_to_token_pool.available_size()``.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1347
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Requests are admitted through a ``PrefillAdder`` in the order produced by
        ``self.policy.calc_priority``. Admission stops when any of these limits is
        reached:

        - the per-micro-batch request budget;
        - the LoRA adapter budget (``max_loras_per_batch``);
        - the token budget.

        Admitted requests are removed from ``self.waiting_queue``. A request
        returned by the adder as ``new_chunked_req`` becomes ``self.chunked_req``.
        With mixed chunked prefill, the running decode batch is folded into the
        returned batch and ``self.running_batch`` is replaced by an empty batch.

        Returns:
            The new ``ScheduleBatch`` (already ``prepare_for_extend``-ed), or
            ``None`` when nothing can be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            # Adapters already used by the running batch.
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                # Admitting this request would exceed the adapter budget.
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Building it inside the comprehension made
        # this filter O(len(waiting_queue) * len(can_run_list)).
        can_run_set = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_set]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1494
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Steps:

        1. Drop finished requests via ``filter_batch``.
        2. If ``check_decode_mem`` fails (or ``TEST_RETRACT`` is set and more
           than 10 requests remain), call ``retract_decode`` and adopt the ratio
           it returns as ``new_token_ratio``. The retracted requests are handed
           to ``_extend_requests_to_queue``.
        3. Otherwise decay ``new_token_ratio`` by ``new_token_ratio_decay``,
           without going below ``min_new_token_ratio``.

        ``batch_is_full`` is cleared whenever the batch shrank. The same (mutated)
        batch object is returned.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        out_of_memory = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        )
        if out_of_memory or (TEST_RETRACT and batch.batch_size() > 10):
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1530

1531
1532
1533
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Before the forward pass this method:

        - increments ``self.forward_ct``;
        - sends the ``stop_profile()`` result to the tokenizer once
          ``profiler_target_forward_ct`` is reached;
        - sleeps ``forward_sleep_time`` seconds when it is set.

        Generation models run through the TP worker, or through the draft worker
        when a speculative algorithm is configured. Other models run through
        ``forward_batch_embedding``.

        Returns:
            ``GenerationBatchResult`` for generation models, otherwise
            ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Check profiler
        if (
            self.profiler_target_forward_ct
            and self.profiler_target_forward_ct <= self.forward_ct
        ):
            self.send_to_tokenizer.send_pyobj(self.stop_profile())

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            # Bind every output up front. Which of them a branch fills in depends
            # on the PP rank and on speculative decoding; the speculative branch
            # never produces PP proxy tensors. Without this, a non-last PP rank
            # with speculative decoding raises NameError when building the result.
            logits_output = None
            next_token_ids = None
            pp_hidden_states_proxy_tensors = None

            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                self.spec_num_total_accepted_tokens += (
                    num_accepted_tokens + batch.batch_size()
                )
                self.spec_num_total_forward_ct += batch.batch_size()
                self.num_generated_tokens += num_accepted_tokens

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
Chayenne's avatar
Chayenne committed
1610

1611
1612
1613
1614
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Dispatch ``result`` to the handler matching ``batch.forward_mode``.

        Routing by forward mode:

        - decode: ``process_batch_result_decode``;
        - extend: ``process_batch_result_prefill``;
        - idle: resolves the last result and signals sampling-info readiness,
          but only when overlap scheduling is enabled;
        - dummy-first: signals sampling-info readiness.

        Afterwards, if health-check signals are owed, one ``HealthCheckOutput``
        is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            # Without the overlap scheduler an idle batch needs no bookkeeping.
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return

        # Return some signal for the health check, so it is not blocked behind a
        # long-context prefill. Note that this path does not check the status of
        # the detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1635
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across DP-attention ranks.

        Thin wrapper that feeds this scheduler's configuration into
        ``prepare_dp_attn_batch_raw``. It returns that function's
        ``(batch, any_extend)`` tuple unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Exchange per-rank batch statistics and prepare ``local_batch`` for DP attention.

        Every rank contributes four int64 values to an all-gather over
        ``tp_cpu_group``:

        - token count;
        - CUDA-graph eligibility;
        - logprob token count;
        - an "is extend" flag.

        The gathered tensor has shape ``(dp_size, attn_tp_size, 4)``, and only
        attention-TP rank 0 of each DP group is read back. CUDA graphs are
        allowed only if every DP group allows them.

        A rank with no batch receives ``get_idle_batch()`` when any DP group
        reports tokens.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` may be ``None``;
            ``any_extend`` tells whether any DP group runs an extend batch.
        """
        # 1) Statistics of this rank.
        if local_batch is None:
            token_count = 0
            logprob_token_count = 0
        elif local_batch.forward_mode.is_decode():
            token_count = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # EAGLE decode counts speculative_num_draft_tokens per request.
                token_count = token_count * speculative_num_draft_tokens
            logprob_token_count = token_count
        else:
            token_count = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            logprob_token_count = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            graph_flag = 1
        else:
            graph_flag = 0
        # TODO(sang): Support cuda graph when idle batch is there.
        if not spec_algorithm.is_none() and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            graph_flag = 0

        extend_flag = local_batch.forward_mode.is_extend() if local_batch else False

        # 2) All-gather the four statistics from every (dp, attn_tp) rank.
        local_info = torch.tensor(
            [token_count, graph_flag, logprob_token_count, extend_flag],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        # 3) Use the values reported by attention-TP rank 0 of each DP group.
        per_dp = global_info[:, 0]
        global_num_tokens = per_dp[:, 0].tolist()
        global_graph_flag = min(per_dp[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp[:, 2].tolist()
        any_extend = any(per_dp[:, 3].tolist())

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is None:
            return None, any_extend

        # TODO: handle the case when moe_dense_tp_size != 1
        if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
            local_batch.global_num_tokens = [token_count]
            local_batch.global_num_tokens_for_logprob = [logprob_token_count]
        else:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = global_graph_flag

        return local_batch, any_extend
1736
1737
1738
1739
1740

    def get_idle_batch(self):
        """Build an empty ``ScheduleBatch`` prepared for an idle forward pass.

        ``prepare_dp_attn_batch`` passes this method as the ``get_idle_batch``
        factory. It is used for a rank that has no local batch while some DP
        group reports tokens to process.
        """
        pools_and_config = (
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch = ScheduleBatch.init_new([], *pools_and_config)
        empty_batch.prepare_for_idle()
        return empty_batch

1751
1752
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Scanning behavior:

        - The queue is scanned from the head, and each pending grammar future is
          polled with a 0.03 s timeout.
        - Scanning stops at the first request that is not ready.
        - A request whose wait count exceeds ``GRAMMAR_TIMEOUT / 0.03`` is
          cancelled and marked with a ``FINISH_ABORT`` reason.

        With tensor parallelism, the ready/aborted counts are max-reduced across
        ranks, so each rank consumes the same prefix of the queue. Both ready and
        aborted requests are passed to ``_extend_requests_to_queue`` and removed
        from ``grammar_queue``.
        """
        num_ready_reqs = 0
        num_abort_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    # At most one request can time out per call: the scan stops here.
                    num_abort_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()

            # Block on requests that another rank already saw as ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
        else:
            # Single rank: the local counts are authoritative. The abort handling
            # below must also run on this path; otherwise a timed-out request is
            # never aborted and blocks the head of the grammar queue indefinitely.
            num_ready_reqs_max = num_ready_reqs
            num_abort_reqs_max = num_abort_reqs

        # NOTE(review): when this rank's num_ready_reqs is below num_ready_reqs_max,
        # this range starts inside the prefix just resolved as ready above;
        # confirm the intended start index for the multi-rank case.
        for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            req.grammar = None
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )
        num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1803
1804
1805
1806
1807
1808
1809
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Mark the next batch's sampling info as ready.

        Does nothing when ``batch.next_batch_sampling_info`` is falsy. When
        grammars are attached, the regex vocab mask is refreshed and
        ``current_stream`` is synchronized before ``sampling_info_done`` is set.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

1810
1811
1812
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        About every ``watchdog_timeout / 2`` seconds, the thread checks whether
        ``forward_ct`` advanced while a batch (``cur_batch``) is in flight. If it
        has not moved for longer than ``watchdog_timeout`` seconds, the thread:

        1. logs batch and memory-pool diagnostics;
        2. calls ``pyspy_dump_schedulers()``;
        3. waits 5 s;
        4. sends ``SIGQUIT`` to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            else:
                # Idle time is not a stall. Restart the clock so that a batch
                # starting after a long idle period is not judged against a stale
                # timestamp (forward_ct is bumped slightly after cur_batch is set).
                self.watchdog_last_time = current
            # True division: `// 2` truncated, so a timeout below 2 s busy-looped.
            time.sleep(self.watchdog_timeout / 2)

        # Snapshot once: the scheduler thread may reset cur_batch concurrently,
        # which would otherwise raise AttributeError in the diagnostics below.
        stuck_batch = self.cur_batch
        if not disable_request_logging() and stuck_batch is not None:
            # Print batch size and memory pool info to check whether there are de-sync issues.
            logger.error(
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1844
1845
1846
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a ``FlushCacheReqInput`` by flushing and reporting whether it succeeded."""
        return FlushCacheReqOutput(success=self.flush_cache())
1847

1848
    def flush_cache(self):
        """Flush the memory pool and cache.

        Flushing is only performed when the waiting queue is empty and the
        running batch is empty. With ``pp_size > 1``, every micro-batch in
        ``running_mbs`` must also be empty. On success it:

        - resets the tree cache and the grammar backend (if any);
        - clears the request/KV pools, including the draft worker's pools under
          speculative decoding;
        - zeroes the generation and speculative-decoding counters;
        - calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if requests were still pending.
        """
        if (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        ):
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger, as the success path does. Module-level
            # logging.warning() calls logging.basicConfig() implicitly when the
            # root logger has no handlers.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Report a copy of the global server args plus runtime statistics.

        Always includes ``last_gen_throughput``. Two entries are conditional:

        - ``avg_spec_accept_length``, when speculative decoding is enabled and
          at least one acceptance has been counted;
        - ``step_time_dict``, when ``RECORD_STEP_TIME`` is set.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }
        spec_enabled = not self.spec_algorithm.is_none()
        if spec_enabled and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict
        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Apply a whitelisted update to ``global_server_args_dict``.

        Only these keys may be changed:

        - ``max_micro_batch_size``, which must lie in
          ``[1, max_running_requests // pp_size]``;
        - ``speculative_accept_threshold_single``;
        - ``speculative_accept_threshold_acc``.

        The update is all-or-nothing: one rejected key leaves every arg
        untouched. On success, the average speculative accept length is logged
        (when available) and the accept-length counters are reset.

        Returns:
            ``SetInternalStateReqOutput``. Its ``updated`` flag reports whether the
            update was applied; ``server_args`` holds the current global args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Report the real outcome; a rejected update must not claim success.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` on behalf of an RPC.

        Any exception raised by the lookup or the call is caught, logged with
        its traceback, and reported in the returned ``RpcReqOutput``.
        ``barrier()`` is called afterwards on both success and failure
        (presumably to keep all ranks in step).
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        # Named `error` rather than `exec` to avoid shadowing the builtin.
        error = None
        try:
            # NOTE(review): `method` selects an arbitrary scheduler attribute;
            # confirm RPC callers are trusted.
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(
                f"Failed to call rpc {recv_req.method}: {str(e)}", exc_info=True
            )

        barrier()
        return RpcReqOutput(success, "" if not error else str(error))

    def save_remote_model(self, params):
        """Save the model to ``params["url"]`` through the TP worker's model runner."""
        target_url = params["url"]
        self.tp_worker.worker.model_runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save the model as shards via the TP worker's model runner.

        ``params`` must provide ``path``, ``pattern`` and ``max_size``; these are
        forwarded as keyword arguments of the same names.
        """
        runner = self.tp_worker.worker.model_runner
        save = runner.save_sharded_model
        save(**{key: params[key] for key in ("path", "pattern", "max_size")})

1971
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests still in the waiting queue are removed from it, and an
        ``AbortReq`` is sent to the tokenizer for each. Matching requests that
        are not yet finished are only flagged with ``to_abort = True``. This
        covers requests in the running batch, and in the current batch when it
        is a different, non-None batch.
        """
        # TODO(lmzheng): abort the requests in the grammar queue.

        # Delete requests in the waiting queue. Collect the indices first, then
        # pop from the highest index down so earlier pops don't shift later ones.
        doomed = [
            idx
            for idx, queued in enumerate(self.waiting_queue)
            if queued.rid.startswith(recv_req.rid)
        ]
        for idx in doomed[::-1]:
            req = self.waiting_queue.pop(idx)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete requests in the running batch (plus the current batch when it
        # is a separate batch).
        reqs = self.running_batch.reqs
        if self.cur_batch is not None and self.cur_batch is not self.running_batch:
            reqs = reqs + self.cur_batch.reqs

        for req in reqs:
            if not req.rid.startswith(recv_req.rid) or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True

    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Placeholder hook; this class does not implement engine pausing.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the load to the TP worker. On success the cache is flushed,
        presumably so that nothing computed with the old weights is reused. On
        failure the worker's message is logged.

        Returns:
            UpdateWeightFromDiskReqOutput: ``(success, message, 0)``.

        Raises:
            RuntimeError: if the weights were updated but ``flush_cache()``
                reported failure.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            # An explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently skip this guard.
            if not self.flush_cache():
                raise RuntimeError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` pair in
        the reply type.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates to the TP worker. On success the cache is flushed; on failure
        the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput: ``(success, message)``. The
            previous annotation ``Tuple[bool, str]`` did not match the returned
            object.

        Raises:
            RuntimeError: if the weights were updated but ``flush_cache()``
                reported failure.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            # An explicit check instead of `assert`, so it also runs under -O.
            if not self.flush_cache():
                raise RuntimeError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates to the TP worker. On success the cache is flushed only when
        ``recv_req.flush_cache`` is truthy; on failure the worker's message is
        logged.

        Returns:
            UpdateWeightsFromTensorReqOutput: ``(success, message)``.

        Raises:
            RuntimeError: if a requested cache flush reported failure.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            # An explicit check instead of `assert`, so it also runs under -O.
            if recv_req.flush_cache and not self.flush_cache():
                raise RuntimeError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch weights through the TP worker and wrap them in the reply type."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model's buffers, pause the memory saver and flush the cache.

        The buffer snapshot is kept in ``self.stashed_model_static_state`` for
        ``resume_memory_occupation``. ``recv_req`` is not read.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)

        adapter.pause()
        # NOTE(review): the flush_cache() result is ignored here, unlike in the
        # weight-update handlers; confirm that this is intentional.
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver and restore the buffers stashed on release.

        The stash in ``self.stashed_model_static_state`` is copied back into the
        model and then deleted. ``recv_req`` is not read.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()

        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` as ``self.forward_sleep_time``.

        A non-positive value is stored as ``None``; ``None`` stays ``None``.
        """
        sleep_time = recv_req.forward_sleep_time
        self.forward_sleep_time = (
            None if sleep_time is not None and sleep_time <= 0 else sleep_time
        )
        return SlowDownReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Route a profiling request to ``start_profile`` or ``stop_profile``.

        Any request type other than ``START_PROFILE`` is treated as a stop.
        """
        if recv_req.type == ProfileReqType.START_PROFILE:
            return self.start_profile(
                recv_req.output_dir,
                recv_req.num_steps,
                recv_req.activities,
                recv_req.with_stack,
                recv_req.record_shapes,
                recv_req.profile_id,
            )
        return self.stop_profile()

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional[ProfileReqOutput]:
        """Start a profiling session for this scheduler process.

        Args:
            output_dir: Directory for the output files. Defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` env var, else ``/tmp``.
            num_steps: If truthy, the session is scheduled to end after this many
                more forward steps and no reply is returned now.
            activities: Any of ``"CPU"``/``"GPU"`` (torch profiler), ``"MEM"``
                (CUDA memory history) and ``"CUDA_PROFILER"``
                (``cudaProfilerStart``). Unknown entries are ignored. Defaults to
                ``["CPU", "GPU"]``.
            with_stack: Passed to ``torch.profiler.profile``; defaults to True.
            record_shapes: Passed to ``torch.profiler.profile``; defaults to False.
            profile_id: Prefix of the output file names. Defaults to the current
                timestamp, because ``stop_profile`` concatenates it as a string.

        Returns:
            A failure ``ProfileReqOutput`` if a session is already running. A
            success ``ProfileReqOutput`` when ``num_steps`` is falsy. Otherwise
            ``None``: the caller is notified when
            ``profiler_target_forward_ct`` is reached.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # stop_profile() builds file names with `self.profiler_id + ...`; a
            # None id would raise TypeError there after the profiler had stopped.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(self) -> ProfileReqOutput:
        """Stop the session begun by ``start_profile`` and write its outputs.

        Depending on the activities recorded at start, this exports the torch
        profiler Chrome trace, dumps the CUDA memory-history snapshot and/or
        calls ``cudaProfilerStop``. Files go to the output directory chosen at
        start and are named ``<profile_id>-TP-<tp_rank>...``. A ``-PP-<pp_rank>``
        part is inserted after the TP rank when ``pp_size > 1``.

        Returns:
            ProfileReqOutput: failure if no session is active, else success.
        """
        if self.profiler_activities is None:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        logger.info("Stop profiling...")

        # A None profiler_id would make the string concatenation below raise
        # TypeError after the torch profiler had already been stopped.
        profile_id = self.profiler_id
        if profile_id is None:
            profile_id = str(time.time())
        rank_suffix = f"-TP-{self.tp_rank}"
        if self.pp_size > 1:
            # tp_rank and pp_rank are tracked separately, so processes on
            # different PP stages may share a tp_rank. Without the PP rank they
            # would write to the same file names.
            rank_suffix += f"-PP-{self.pp_rank}"
        file_prefix = profile_id + rank_suffix

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    file_prefix + ".trace.json.gz",
                )
            )

        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                file_prefix + "-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        return ProfileReqOutput(success=True, message="Succeeded")

    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump recording on the global expert-distribution recorder.

        Raises:
            ValueError: if ``recv_req`` is not START_RECORD, STOP_RECORD or
                DUMP_RECORD.
        """
        dispatch = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for request_kind, action in dispatch:
            if recv_req == request_kind:
                recorder = get_global_expert_distribution_recorder()
                getattr(recorder, action)()
                return ExpertDistributionReqOutput()
        raise ValueError("Unrecognized ExpertDistributionReq value")

    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a ``Session`` under ``recv_req.session_id``.

        Logs a warning and replies with success=False when the id is already in
        use or is None. Otherwise the session is registered and the reply
        reports success=True.
        """
        session_id = recv_req.session_id

        # Reject duplicates first, then a missing id.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove session ``recv_req.session_id``; warn if no such session exists."""
        session_id = recv_req.session_id
        try:
            del self.sessions[session_id]
        except KeyError:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")

    def get_print_prefix(self):
        """Return a rank tag such as ``" DP0 TP1 PP2"`` for log messages.

        The DP part appears when ``attn_dp_rank`` is not None. The TP part
        appears when ``server_args.tp_size > 1`` and the PP part when
        ``pp_size > 1``. Each part has a leading space, and the result is ``""``
        when none apply.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f"DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f"TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f"PP{self.pp_rank}")
        return "".join(" " + part for part in parts)

    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2231

2232
2233
2234
2235
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    A health check has a string ``rid`` that starts with ``"HEALTH_CHECK"``.
    Objects with no ``rid`` attribute, or whose ``rid`` is not a string (for
    example ``None`` or a list of ids), return False instead of raising
    ``AttributeError``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2250
2251
2252
2253
2254
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Sets up the process: title, logging, fault handler, optional CPU affinity
    and the VLM embedding cache. It then builds a ``Scheduler``, sends a
    ``"ready"`` message with the token limits through ``pipe_writer``, and
    runs the event loop selected by the disaggregation mode, ``pp_size`` and
    the overlap setting. On any exception the traceback is logged and SIGQUIT
    is sent to the parent process.

    Args:
        server_args: Server configuration.
        port_args: Port configuration.
        gpu_id: GPU used by this process.
        tp_rank: Tensor-parallel rank.
        pp_rank: Pipeline-parallel rank.
        dp_rank: Data-parallel rank, or None. If None and the ``SGLANG_DP_RANK``
            environment variable is set, its integer value is used instead.
        pipe_writer: Connection on which the ready message is sent.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolve this before building the prefix; otherwise the DP rank is missing
    # from the process title and the log prefix.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB (SGLANG_VLM_CACHE_SIZE_MB, default 100),
    # converted to bytes.
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)