scheduler.py 88.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
45
46
47
48
49
50
51
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
52
    TransferBackend,
Byron Hsu's avatar
Byron Hsu committed
53
)
54
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
55
56
57
58
59
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
60
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
61
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
62
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
63
64
from sglang.srt.managers.io_struct import (
    AbortReq,
65
    CloseSessionReqInput,
66
    ExpertDistributionReq,
67
    ExpertDistributionReqOutput,
68
69
    FlushCacheReqInput,
    FlushCacheReqOutput,
70
71
    GetInternalStateReq,
    GetInternalStateReqOutput,
72
73
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
74
    HealthCheckOutput,
75
76
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
77
78
    OpenSessionReqInput,
    OpenSessionReqOutput,
79
    ProfileReq,
80
81
    ProfileReqOutput,
    ProfileReqType,
82
83
84
85
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
86
87
    RpcReqInput,
    RpcReqOutput,
88
89
    SetInternalStateReq,
    SetInternalStateReqOutput,
90
91
    SlowDownReqInput,
    SlowDownReqOutput,
92
93
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
94
95
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
96
97
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
98
99
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
100
101
102
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
103
    MultimodalInputs,
104
105
    Req,
    ScheduleBatch,
106
    global_server_args_dict,
107
)
108
109
110
111
112
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
113
114
115
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
116
from sglang.srt.managers.session_controller import Session
117
from sglang.srt.managers.tp_worker import TpModelWorker
118
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
119
from sglang.srt.managers.utils import validate_input_length
120
from sglang.srt.mem_cache.chunk_cache import ChunkCache
121
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
122
from sglang.srt.mem_cache.radix_cache import RadixCache
123
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
124
125
126
127
128
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
129
from sglang.srt.reasoning_parser import ReasoningParser
130
from sglang.srt.server_args import PortArgs, ServerArgs
131
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
132
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
133
from sglang.srt.utils import (
134
    DynamicGradMode,
135
136
    broadcast_pyobj,
    configure_logger,
137
    crash_on_warnings,
138
    get_bool_env_var,
139
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
140
    kill_itself_when_parent_died,
141
    point_to_point_pyobj,
142
    pyspy_dump_schedulers,
143
    set_gpu_proc_affinity,
144
145
146
    set_random_seed,
    suppress_other_loggers,
)
147
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
148

149
150
# Module-level recorder instance created at import time.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by Scheduler.expert_distribution_handle — confirm.
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Debug flags, read once from the environment when this module is imported.
# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# NOTE(review): presumably gates recording of per-step times (see
# Scheduler.step_time_dict); the consumer is outside this view — confirm.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")


@dataclass
class GenerationBatchResult:
    """Result of running one generation batch on the model worker.

    Attributes:
        logits_output: Logits-processor output for the batch, if any.
        pp_hidden_states_proxy_tensors: NOTE(review): presumably the tensors
            forwarded to the next pipeline stage when this rank is not the
            last PP rank — confirm against ``run_batch``.
        next_token_ids: Sampled next-token ids, if any.
        extend_input_len_per_req: Per-request extend input lengths.
        extend_logprob_start_len_per_req: Per-request logprob start offsets.
        bid: Identifier of the batch that produced this result.
    """

    logits_output: Optional[LogitsProcessorOutput]
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    next_token_ids: Optional[List[int]]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Result of running one embedding batch on the model worker.

    Attributes:
        embeddings: Embeddings produced for the batch.
        bid: Identifier of the batch that produced these embeddings.
    """

    embeddings: torch.Tensor
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
179
180
181
182
183
184
185
186
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler for one (tp_rank, pp_rank) GPU worker.

        Args:
            server_args: Server-wide configuration. Most scheduling knobs are
                copied from here onto ``self``.
            port_args: IPC endpoint names (scheduler input, tokenizer,
                detokenizer, rpc) and the NCCL port used by the workers.
            gpu_id: GPU index handed to the model worker and the draft worker.
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the model/draft workers.
                ``self.dp_rank`` is taken from
                ``compute_dp_attention_world_info`` rather than from this
                argument.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info. Note that self.dp_rank is derived from the
        # attention world layout, not from the dp_rank argument.
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only attention-TP rank 0 on
        # pipeline stage 0 owns real sockets. Every other rank gets no-op
        # senders, so call sites can send unconditionally.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            # NOTE(review): only the first id of the encoded end-of-think marker
            # is kept; confirm the marker is a single token for every parser.
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        TpWorkerClass = TpModelWorkerClient if self.enable_overlap else TpModelWorker
        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        # Fill the micro-batch default only after merging the worker's args.
        # Otherwise a None coming back from the worker could overwrite it.
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. None or a non-positive size (e.g. -1) disables
        # it. None is accepted here because init_memory_pool_and_cache also
        # treats it as a possible value.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # new_token_ratio starts at init_new_token_ratio and moves toward
        # min_new_token_ratio in steps of new_token_ratio_decay. Both ratios
        # are capped at 1.0.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. parent_process is resolved before the thread is
        # started, so every attribute set here already exists if
        # watchdog_thread reads it.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Init prefill/decode disaggregation (needs memory pools and tree cache)
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``:

        * ``server_args.skip_tokenizer_init``: tokenizer and processor are None.
        * multimodal model: a processor is loaded and the tokenizer is taken
          from it.
        * otherwise: a plain tokenizer is loaded and the processor is None.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Text-only models have no processor. Set it explicitly so the
            # attribute exists on every code path instead of being missing.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the model worker and build the prefix cache.

        Cache selection (first match wins):

        * ``ChunkCache`` when ``server_args.chunked_prefill_size`` is not None
          and the radix cache is disabled;
        * ``HiRadixCache`` when hierarchical caching is enabled;
        * ``RadixCache`` otherwise, constructed with
          ``disable=server_args.disable_radix_cache``.

        Also sets ``self.decode_mem_cache_buf_multiplier``. It is 1 without
        speculative decoding, otherwise
        ``speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_steps``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        # NOTE(review): this tests the raw server_args value. A negative
        # chunked_prefill_size, which __init__ treats as "disabled", therefore
        # still selects ChunkCache when the radix cache is disabled. Confirm
        # this is intended.
        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

    def init_metrics(self):
        """Reset the scheduler's statistics counters and set up metrics export.

        ``self.metrics_collector`` is only created when ``self.enable_metrics``
        is set. It is labelled with the served model name and the engine type
        ``"unified"``.
        """
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        # NOTE(review): presumably speculative-decoding acceptance counters,
        # updated outside this view — confirm.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    "engine_type": "unified",
                },
            )

    def init_disaggregation(self):
        """Set up the queues used for prefill/decode disaggregation.

        Always sets ``self.transfer_backend``. Then, depending on
        ``self.disaggregation_mode``:

        * DECODE: builds the transfer queue (decode requests polling their KV
          cache), the pre-allocation queue, and the ``num_tokens_pre_allocated``
          counter. Buffer capacity is twice ``req_to_token_pool.size``.
        * PREFILL: builds the bootstrap queue and an empty in-flight list for
          prefill requests whose KV cache is still being sent. Buffer capacity
          is twice ``max_running_requests``.
        * any other mode: nothing else is created.

        Must run after the memory pools, tree cache and attention TP CPU group
        are initialized.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def _make_metadata_buffers(buffer_size):
            """Return (index allocator, aux dtype, metadata buffers) with ``buffer_size`` slots."""
            allocator = ReqToMetadataIdxAllocator(buffer_size)
            aux_dtype = torch.int32
            # A list of metadata buffers. The shape is (b, metadata_size) where
            # b corresponds to a max running requests. The row size
            # (last dim * dtype.itemsize) must be at least 64 bytes to work
            # with RDMA, so we pad it: 16 x int32 = 64 bytes.
            output_id_buffer = torch.zeros(
                (buffer_size, 16), dtype=aux_dtype, device="cpu"
            )
            return allocator, aux_dtype, [output_id_buffer]

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = _make_metadata_buffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = _make_metadata_buffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration:

        1. receives and processes new requests;
        2. picks the next batch and runs it;
        3. processes the batch result synchronously.

        When there is nothing to run, it instead runs ``check_memory()`` and
        resets ``new_token_ratio`` to its initial value. The loop has no exit
        path other than an exception.
        """
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the current batch first and only then
        processes the result of the previous batch. Launched batches (as
        copies) and their ``run_batch`` results wait in ``self.result_queue``
        until the next iteration. On the very first launch, a DUMMY_FIRST batch
        is processed to start the pipeline. The loop has no exit path other
        than an exception.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: we should use current launched batch's launch_done event Instead of the last batch's
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Every outer iteration visits ``self.pp_size`` microbatch slots. Each slot
        keeps its own running batch (``self.running_mbs``) and last batch
        (``last_mbs``). These are swapped into ``self.running_batch`` /
        ``self.last_batch`` before the shared scheduling helpers run.

        For each slot this rank:

        1. receives new requests, then schedules and runs a batch for the slot;
        2. on the last PP rank, sends that batch's ``next_token_ids`` through
           ``self.pp_group.send_tensor_dict``;
        3. if the *next* slot has a batch, receives its ``next_token_ids``
           through ``self.pp_group.recv_tensor_dict`` and post-processes that
           batch with ``process_batch_result``;
        4. on non-last ranks, forwards three things to the next stage: the
           outputs received during the previous slot iteration, the raw
           received requests, and this batch's hidden-state proxy tensors.

        NOTE(review): these sends/receives have to pair up with the calls made
        by the neighbouring PP ranks, so their order is load-bearing. Change
        with care.
        """
        # Batch scheduled for each microbatch slot in the current round.
        mbs = [None] * self.pp_size
        # Last batch post-processed for each slot; it becomes that slot's
        # self.last_batch the next time the slot is scheduled.
        last_mbs = [None] * self.pp_size
        # Each slot gets its own, initially empty, running batch.
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        # Batch id of the result produced for each slot.
        bids = [None] * self.pp_size
        # Outputs waiting to be forwarded to the next stage.
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Swap in this slot's state so the shared helpers
                # (get_next_batch_to_run, process_batch_result) act on it.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                # get_next_batch_to_run may rebind self.running_batch; save it back.
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                # Only receive when the next slot actually has a batch in flight.
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Wrap the received token ids in a minimal result object;
                    # logits and hidden states are not available on this path.
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # carry the outputs to the next stage
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    if pp_outputs:
                        # send the outputs from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                if not self.pp_group.is_last_rank:
                    # send out reqs to the next stage
                    dp_offset = self.dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # Outputs received for the next slot are forwarded during the
                # next slot iteration.
                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

799
800
    def recv_requests(self) -> List[Req]:
        """Collect incoming requests and share them with the other ranks.

        Only the attention-TP leader (``attn_tp_rank == 0``) receives anything;
        every other rank starts with ``None``.

        - On the first pipeline stage, the leader drains, without blocking,
          every pending message from the tokenizer socket and then from the
          RPC socket.
        - On later stages, the leader instead receives the list forwarded by
          the previous stage through ``point_to_point_pyobj``.

        The list is then broadcast so that all ranks return the same requests.
        With DP attention, generate/embedding requests are broadcast within the
        attention-TP group and all other (control) requests within the full TP
        group. Otherwise the whole list is broadcast within the TP group when
        ``tp_size != 1``.
        """
        is_leader = self.attn_tp_rank == 0

        def drain(socket):
            # Pull messages until the non-blocking receive raises ZMQError.
            drained = []
            while True:
                try:
                    message = socket.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    return drained
                drained.append(message)

        if not is_leader:
            recv_reqs = None
        elif self.pp_rank == 0:
            recv_reqs = drain(self.recv_from_tokenizer)
            recv_reqs.extend(drain(self.recv_from_rpc))
        else:
            dp_offset = self.dp_rank * self.attn_tp_size
            this_rank = self.pp_rank * self.tp_size + dp_offset
            recv_reqs = point_to_point_pyobj(
                [],
                this_rank,
                self.world_group.cpu_group,
                (self.pp_rank - 1) * self.tp_size + dp_offset,
                this_rank,
            )

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(
                    recv_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return recv_reqs

        if is_leader:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs = []
            control_reqs = []
            for req in recv_reqs:
                if isinstance(req, work_types):
                    work_reqs.append(req)
                else:
                    control_reqs.append(req)
        else:
            work_reqs = None
            control_reqs = None

        # Work requests stay within the attention-TP group; control requests
        # go to the whole TP group. The broadcast order must match on all ranks.
        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

    def process_input_requests(self, recv_reqs: List):
878
        for recv_req in recv_reqs:
879
880
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req) and (
Lianmin Zheng's avatar
Lianmin Zheng committed
881
                self.chunked_req is not None or not self.running_batch.is_empty()
882
883
884
885
            ):
                self.return_health_check_ct += 1
                continue

886
            output = self._request_dispatcher(recv_req)
887
            if output is not None:
888
889
890
891
892
                if isinstance(output, RpcReqOutput):
                    if self.recv_from_rpc is not None:
                        self.recv_from_rpc.send_pyobj(output)
                else:
                    self.send_to_tokenizer.send_pyobj(output)
893
894
895
896
897

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        A fresh ``Req`` is created unless ``recv_req`` names an existing
        session, in which case the session builds it.

        Requests that fail a check are still enqueued, and the method returns
        early. The checks are: unknown session id, a multimodal prompt that is
        too long after padding, a prompt longer than ``max_req_input_len``, and
        ``logprob_start_len`` past the end of the prompt.

        A request carrying a constrained-decoding spec whose grammar is not
        cached yet is parked in ``self.grammar_queue``. Everything else goes
        through ``_add_request_to_queue``.
        """
        session_params = recv_req.session_params
        session_id = None if session_params is None else session_params.id

        if session_id is None or session_id not in self.sessions:
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                recv_req.input_ids = [1] * len(recv_req.input_embeds)

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Each fragment ends with a space so the joined sentences do
                # not run together. The flag name matches the
                # `enable_custom_logit_processor` server argument checked above.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if session_id is not None:
                # A session id was supplied but no such session is open.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[session_id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so that prompt + output stays below max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request. Precedence: JSON schema, regex,
        # EBNF, then structural tag. Every branch uses `is not None`, so the key
        # is always bound whenever a grammar is requested (an empty
        # structural_tag used to slip past a truthiness test and leave `key`
        # unbound).
        params = req.sampling_params
        key = None
        if params.json_schema is not None:
            key = ("json", params.json_schema)
        elif params.regex is not None:
            key = ("regex", params.regex)
        elif params.ebnf is not None:
            key = ("ebnf", params.ebnf)
        elif params.structural_tag is not None:
            key = ("structural_tag", params.structural_tag)

        if key is not None:
            assert self.grammar_backend is not None
            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Grammar not compiled yet: park the request until it is ready.
                req.grammar = self.grammar_backend.get_future_value(key)
                req.queue_time_start = time.time()
                self.grammar_queue.append(req)
                return

        self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req`` with its enqueue time and hand it to this mode's queue.

        The destination depends on the disaggregation mode:

        - prefill-disaggregation servers use ``disagg_prefill_bootstrap_queue``;
        - decode-disaggregation servers use ``disagg_decode_prealloc_queue``;
        - any other server appends to ``self.waiting_queue``.
        """
        req.queue_time_start = time.time()
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            enqueue = self.disagg_prefill_bootstrap_queue.add
        elif mode == DisaggregationMode.DECODE:
            enqueue = self.disagg_decode_prealloc_queue.add
        else:
            enqueue = self.waiting_queue.append
        enqueue(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests at once to the queue for this mode.

        Decode-disaggregation servers extend ``disagg_decode_prealloc_queue``;
        every other server extends ``self.waiting_queue``.

        NOTE(review): ``is_retracted`` is accepted but not read here. Unlike
        ``_add_request_to_queue``, this method neither sets
        ``queue_time_start`` nor has a prefill-specific branch.
        """
        if self.disaggregation_mode != DisaggregationMode.DECODE:
            self.waiting_queue.extend(reqs)
            return
        self.disagg_decode_prealloc_queue.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs are first expanded with ``self.pad_input_ids_func``.
        A request whose (expanded) prompt is too long is still enqueued, and
        the method returns early. In the multimodal case its ``finished_reason``
        is set to a ``FINISH_ABORT`` before enqueueing.

        Every path enqueues through ``_add_request_to_queue``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper, which also stamps queue_time_start,
                # like every other enqueue path, so that disaggregation modes
                # route the aborted request to the correct queue instead of
                # always appending to waiting_queue.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a newly scheduled prefill batch.

        The method does three things:

        1. Updates the prefill throughput bookkeeping: sets
           ``last_input_throughput``, resets ``num_prefill_tokens`` and
           updates ``_largest_prefill_len``.
        2. Logs token-usage and queue counters.
        3. When metrics are enabled, pushes the same numbers to
           ``self.metrics_collector``, together with the cache-hit rate and
           the mean queueing latency of ``can_run_list``.
        """
        elapsed = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / elapsed
        self.num_prefill_tokens = 0

        free_or_evictable = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_or_evictable
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        num_new_seq = len(can_run_list)
        num_waiting = len(self.waiting_queue)
        parts = [
            "Prefill batch. ",
            f"#new-seq: {num_new_seq}, ",
            f"#new-token: {adder.log_input_tokens}, ",
            f"#cached-token: {adder.log_hit_tokens}, ",
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, ",
            f"#running-req: {running_bs}, ",
        ]
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            parts += [
                f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, ",
                f"#queue-req: {num_waiting}, ",
                f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} ",
            ]
        else:
            parts.append(f"#queue-req: {num_waiting}")
        logger.info("".join(parts))

        if not self.enable_metrics:
            return

        hit_rate = adder.log_hit_tokens / (
            adder.log_input_tokens + adder.log_hit_tokens
        )
        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        stats.num_queue_reqs = num_waiting
        stats.cache_hit_rate = hit_rate

        queued_for = sum(
            req.queue_time_end - req.queue_time_start for req in can_run_list
        )
        stats.avg_request_queue_latency = queued_for / num_new_seq

        self.metrics_collector.log_stats(stats)

    def log_decode_stats(self, running_batch=None):
        """Log a one-line summary of decode progress.

        ``running_batch`` falls back to ``self.running_batch`` when it is falsy.

        The method:

        - updates the generation throughput counters;
        - optionally records per-step time under ``RECORD_STEP_TIME``;
        - when speculative decoding is active, folds the pending
          accepted-token / forward counters into the cumulative totals and
          resets them;
        - logs the summary line;
        - when metrics are enabled, forwards the numbers to
          ``self.metrics_collector``.
        """
        batch = running_batch or self.running_batch

        elapsed = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0

        num_running_reqs = len(batch.reqs)
        free_or_evictable = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_or_evictable

        if RECORD_STEP_TIME:
            per_step = elapsed / self.server_args.decode_log_interval
            self.step_time_dict[num_running_reqs].append(per_step)

        parts = [
            "Decode batch. ",
            f"#running-req: {num_running_reqs}, ",
            f"#token: {num_used}, ",
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, ",
        ]

        spec_accept_length = 0
        if not self.spec_algorithm.is_none():
            accepted = self.spec_num_total_accepted_tokens
            forward_ct = self.spec_num_total_forward_ct
            spec_accept_length = accepted / forward_ct
            self.cum_spec_accept_length += accepted
            self.cum_spec_accept_count += forward_ct
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            parts.append(f"accept len: {spec_accept_length:.2f}, ")

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            parts.append(
                f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
            )

        parts.append(f"gen throughput (token/s): {self.last_gen_throughput:.2f}, ")
        parts.append(f"#queue-req: {len(self.waiting_queue)}")
        logger.info("".join(parts))

        if self.enable_metrics:
            stats = self.stats
            stats.num_running_reqs = num_running_reqs
            stats.num_used_tokens = num_used
            stats.token_usage = num_used / self.max_total_num_tokens
            stats.cache_hit_rate = 0.0
            stats.gen_throughput = self.last_gen_throughput
            stats.num_queue_reqs = len(self.waiting_queue)
            stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(stats)

    def check_memory(self):
        """Detect leaked KV-cache or request slots; also emit idle-time metrics.

        This is called from the event loops when the server is idle. At that
        point two things are expected:

        - Every KV token slot is either free or evictable from the tree cache.
          When the hierarchical cache is enabled, the protected portion is
          excluded from this count.
        - Every ``req_to_token_pool`` slot is free.

        A mismatch triggers ``warnings.warn``. If ``crash_on_warnings()``
        returns true, it also raises ``ValueError``.

        When metrics are enabled on attention-TP rank 0, usage stats are logged
        if more than 30 seconds have passed since the collector's last log.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        free_req_slots = len(self.req_to_token_pool.free_slots)
        if free_req_slots != self.req_to_token_pool.size:
            # The trailing space after "detected!" keeps the message readable;
            # previously it rendered as "detected!available_size=...".
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={free_req_slots}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
1282
        # Merge the prefill batch into the running batch
1283
1284
1285
1286
1287
1288
1289
1290
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
Lianmin Zheng's avatar
Lianmin Zheng committed
1291
        if self.last_batch and self.last_batch.forward_mode.is_extend():
1292
1293
1294
1295
            if self.last_batch.chunked_req is not None:
                # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)
Lianmin Zheng's avatar
Lianmin Zheng committed
1296

1297
            # Filter batch
1298
            last_bs = self.last_batch.batch_size()
1299
1300
1301
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
1302
            if self.last_batch.batch_size() < last_bs:
Lianmin Zheng's avatar
Lianmin Zheng committed
1303
                self.running_batch.batch_is_full = False
1304

1305
            # Merge the new batch into the running batch
1306
            if not self.last_batch.is_empty():
Lianmin Zheng's avatar
Lianmin Zheng committed
1307
                if self.running_batch.is_empty():
1308
1309
                    self.running_batch = self.last_batch
                else:
Lianmin Zheng's avatar
Lianmin Zheng committed
1310
                    # Merge running_batch with prefill batch
1311
                    self.running_batch.merge_batch(self.last_batch)
1312

1313
1314
        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
1315
1316
1317
1318
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
Lianmin Zheng's avatar
Lianmin Zheng committed
1319
            if not self.running_batch.is_empty():
1320
                self.running_batch = self.update_running_batch(self.running_batch)
Lianmin Zheng's avatar
Lianmin Zheng committed
1321
1322
1323
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None
1324

1325
        # Handle DP attention
1326
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
Lianmin Zheng's avatar
Lianmin Zheng committed
1327
            ret, _ = self.prepare_dp_attn_batch(ret)
1328
1329

        return ret
1330

1331
1332
1333
1334
1335
1336
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests fit beside ``running_bs`` running ones.

        The limit is ``global_server_args_dict["max_micro_batch_size"]``. With
        pipeline parallelism (``pp_size > 1``), the result is additionally
        capped by the free slots in ``req_to_token_pool``. The value may be
        zero or negative.
        """
        headroom = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return headroom
        return min(headroom, self.req_to_token_pool.available_size())

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Requests are pulled from ``self.waiting_queue`` through a ``PrefillAdder``
        until a limit is hit. Limits are: LoRA adapter count, allocatable request
        slots, or an ``AddReqResult`` other than ``CONTINUE``. An in-flight
        chunked request (``self.chunked_req``) is resumed first.

        Side effects:
            - Removes the scheduled requests from ``self.waiting_queue``.
            - May set ``self.running_batch.batch_is_full``.
            - May update ``self.chunked_req``.
            - With mixed chunked prefill, folds the running decode batch into
              the new batch and resets ``self.running_batch`` to an empty batch.

        Returns:
            The new ``ScheduleBatch`` prepared for extend, or ``None`` when no
            request can be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            # LoRA paths already used by the running batch.
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.time()

        # Build the membership set once. Rebuilding it for every element of the
        # waiting queue (as before) made this filter O(len(queue) * len(batch)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1484
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Refresh the running decode batch before its next decode step.

        Steps:
            1. Filter the batch via ``filter_batch``.
            2. If decode memory is insufficient (or ``TEST_RETRACT`` forces it
               for batches larger than 10), retract requests back to the queue
               and adopt the ratio returned by ``retract_decode``.
            3. Otherwise decay ``new_token_ratio`` toward ``min_new_token_ratio``.
            4. Prepare the batch tensors for decode.

        ``batch_is_full`` is cleared whenever the batch shrank. The same
        (mutated) batch object is returned.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        out_of_memory = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not out_of_memory:
            # Enough memory: relax the reservation ratio a bit, bounded below.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            old_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1520

1521
1522
1523
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass on ``batch`` and wrap its outputs.

        Bookkeeping done first:
            - Increments ``forward_ct``.
            - Stops the profiler once ``profiler_target_forward_ct`` is reached.
            - Sleeps for ``forward_sleep_time`` seconds when it is set.

        Forward paths:
            - Generation models run through ``tp_worker``, or ``draft_worker``
              when a speculative algorithm is active. Only the last
              pipeline-parallel rank yields logits and token ids; other ranks
              yield hidden-state proxy tensors.
            - Embedding/reward models run ``forward_batch_embedding``.

        Returns:
            ``GenerationBatchResult`` for generation models, otherwise
            ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Check profiler
        profiler_target = self.profiler_target_forward_ct
        if profiler_target and profiler_target <= self.forward_ct:
            self.stop_profile()

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            forward_out = self.tp_worker.forward_batch_generation(worker_batch)
            if self.pp_group.is_last_rank:
                logits_output, next_token_ids = forward_out
            else:
                pp_hidden_states_proxy_tensors, _ = forward_out
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        # Only names bound for this rank are referenced below.
        # NOTE(review): the speculative path never binds
        # pp_hidden_states_proxy_tensors, so it would fail on a non-last PP rank.
        if self.pp_group.is_last_rank:
            batch.output_ids = next_token_ids
            rank_outputs = dict(
                logits_output=logits_output,
                pp_hidden_states_proxy_tensors=None,
                next_token_ids=next_token_ids,
            )
        else:
            rank_outputs = dict(
                logits_output=None,
                pp_hidden_states_proxy_tensors=pp_hidden_states_proxy_tensors,
                next_token_ids=None,
            )

        # Output processing needs these per-request values. Overlap scheduling
        # may modify them later, so snapshot them now.
        if batch.return_logprob:
            extend_input_lens = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_lens = [r.extend_logprob_start_len for r in batch.reqs]
        else:
            extend_input_lens = None
            extend_logprob_start_lens = None

        return GenerationBatchResult(
            **rank_outputs,
            extend_input_len_per_req=extend_input_lens,
            extend_logprob_start_len_per_req=extend_logprob_start_lens,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1598

1599
1600
1601
1602
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Post-process a finished forward pass according to its forward mode.

        Dispatch by mode:
            - Decode and extend batches go to their dedicated handlers.
            - Idle batches, when overlap scheduling is on, resolve the last
              batch result.
            - Idle batches with pending next-batch sampling info, and
              dummy-first batches, only finalize that sampling info.

        Afterwards, one pending health-check signal is sent if any remain.
        """

        def finalize_next_sampling_info():
            # Build the regex vocab mask, synchronize the current stream, then
            # signal that the next batch's sampling info is ready.
            sampling_info = batch.next_batch_sampling_info
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                if batch.next_batch_sampling_info:
                    finalize_next_sampling_info()
        elif mode.is_dummy_first():
            finalize_next_sampling_info()

        if self.return_health_check_ct:
            # Send the health-check signal from here so a long-context prefill
            # cannot block it. Caveat: this path does not check the status of
            # the detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1628
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange batch metadata with the other DP-attention ranks.

        Thin wrapper that binds this scheduler's settings to
        ``prepare_dp_attn_batch_raw``. Returns that function's
        ``(batch, any_extend_in_batch)`` tuple.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Share per-rank batch info across DP ranks and align local state.

        Each rank all-gathers a 4-vector over ``tp_cpu_group`` into a
        ``(dp_size, attn_tp_size, 4)`` tensor. The vector is
        ``[num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend]``.
        Only attention-TP rank 0 of each DP group is read back.

        After the exchange:
            - If this rank has no batch but some DP rank has tokens, an idle
              batch is created via ``get_idle_batch`` so this rank joins the step.
            - The gathered token counts are stored on the local batch.
            - Unless CUDA graphs are disabled, the local batch gets the
              group-wide (minimum) CUDA-graph flag.

        Returns:
            ``(local_batch_or_None, any_rank_is_extend)``.
        """
        has_spec = not spec_algorithm.is_none()

        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if has_spec and spec_algorithm.is_eagle():
                # EAGLE decode steps count speculative_num_draft_tokens per request.
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        can_cuda_graph = (
            1
            if local_batch is None or local_batch.forward_mode.is_decode_or_idle()
            else 0
        )
        if has_spec and (local_batch is None or local_batch.forward_mode.is_idle()):
            # TODO(sang): Support cuda graph when idle batch is there.
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        # Read attention-TP rank 0 of every DP group.
        per_dp = global_info[:, 0, :]
        global_num_tokens = per_dp[:, 0].tolist()
        can_cuda_graph = min(per_dp[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp[:, 2].tolist()
        is_extend_in_batch = per_dp[:, 3].tolist()

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1720
1721
1722
1723
1724

    def get_idle_batch(self):
        """Build an empty ``ScheduleBatch`` and prepare it as an idle batch.

        Passed to ``prepare_dp_attn_batch_raw``. There it stands in for a
        missing local batch when other DP ranks still have tokens to process.
        """
        batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1735
1736
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        How readiness is determined:
            - Each queued request's ``grammar`` is a future. Futures are polled
              in queue order with a 50 ms timeout, stopping at the first one
              that is not ready.
            - When the (attention) TP group has more than one rank, the ranks
              all-reduce the MAX ready count. Each rank then blocks on the
              futures it had not resolved yet, so every rank moves the same
              prefix of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                # Public alias of the private futures._base.TimeoutError.
                break
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            # Block until the futures that other ranks already resolved are ready here too.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Every ``watchdog_timeout / 2`` seconds it checks whether ``forward_ct``
        has advanced while a batch (``cur_batch``) is active. If it has not
        advanced for longer than ``watchdog_timeout``, the thread:
            1. logs batch and memory-pool state,
            2. dumps scheduler stacks with py-spy,
            3. flushes stdio,
            4. sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division (`// 2`) turned any timeout below
            # 2 s into a zero-length sleep, i.e. a busy loop.
            time.sleep(self.watchdog_timeout / 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        # Snapshot cur_batch: another thread may reset it to None after the
        # check above. Dereferencing None would kill this thread before
        # SIGQUIT is sent.
        stuck_batch = self.cur_batch
        if stuck_batch is not None:
            logger.error(
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )
        else:
            logger.error(
                "self.cur_batch=None, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1798
1799
1800
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether the flush happened."""
        return FlushCacheReqOutput(success=self.flush_cache())
1801

1802
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is refused, with a warning, while any request is waiting or
        running, including any pipeline-parallel micro-batch when
        ``pp_size != 1``.

        Otherwise it:
            - clears ``cur_batch`` and ``last_batch``;
            - resets the tree cache and the grammar backend (if any);
            - clears the req-to-token and KV pools, and the draft worker's pools
              when speculative decoding is on;
            - zeroes the throughput/speculative counters;
            - calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False otherwise.
        """
        has_pending = (
            len(self.waiting_queue) != 0
            or not self.running_batch.is_empty()
            or (self.pp_size != 1 and not all(mb.is_empty() for mb in self.running_mbs))
        )
        if has_pending:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            draft_runner = self.draft_worker.model_runner
            draft_runner.req_to_token_pool.clear()
            draft_runner.token_to_kv_pool_allocator.clear()

        # Reset throughput and speculative-decoding statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the global server args plus runtime statistics.

        The snapshot always contains ``last_gen_throughput``. It also contains:
            - ``avg_spec_accept_length``, when speculative decoding is on and
              at least one acceptance has been recorded;
            - ``step_time_dict``, when ``RECORD_STEP_TIME`` is set.
        """
        state = {**global_server_args_dict}
        state["last_gen_throughput"] = self.last_gen_throughput

        uses_spec = not self.spec_algorithm.is_none()
        if uses_spec and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Allowed keys:
            - ``max_micro_batch_size``, which must lie in
              ``[1, max_running_requests // pp_size]``;
            - ``speculative_accept_threshold_single``;
            - ``speculative_accept_threshold_acc``.

        Validation stops at the first rejected key, and then nothing is
        applied. On success:
            - the cumulative speculative accept length is logged (if recorded)
              and the statistics are reset;
            - every provided key is written into ``global_server_args_dict``.

        Returns:
            ``SetInternalStateReqOutput``. Its ``updated`` field reports
            whether the update was actually applied; ``server_args`` holds the
            current global args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Report rejection to the caller instead of always claiming success.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        Any exception raised by the target is logged and returned as a message
        rather than propagated. All ranks then meet at a
        ``torch.distributed.barrier()`` before replying.

        NOTE(review): the method is resolved with ``getattr`` on the scheduler,
        so any attribute name can be invoked. This presumably relies on the
        sender being trusted — confirm.

        Returns:
            ``RpcReqOutput(success, message)``. ``message`` is empty on success
            and holds the exception text on failure.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        error = None  # renamed from `exec`, which shadowed the builtin
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if not error else str(error))

    def save_remote_model(self, params):
        """Ask the worker's model runner to save the model to ``params["url"]``."""
        url = params["url"]
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Ask the worker's model runner to save sharded weights.

        ``params`` must provide ``"path"``, ``"pattern"`` and ``"max_size"``.
        These are forwarded as keyword arguments.
        """
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

1925
1926
    def abort_request(self, recv_req: AbortReq):
        """Abort the first request whose rid starts with ``recv_req.rid``.

        Search order:
            1. The waiting queue. A matching request is removed from the queue.
            2. The running batch. The first matching unfinished request is
               marked with ``to_abort = True``.

        At most one request is affected per call. The previous index-list
        bookkeeping ("sort in reverse") could only ever hold one index because
        of the ``break``, so it is replaced by a direct pop. Behavior is
        unchanged.
        """
        # Delete requests in the waiting queue
        for i, req in enumerate(self.waiting_queue):
            if req.rid.startswith(recv_req.rid):
                # Returning immediately keeps pop-during-enumerate safe.
                self.waiting_queue.pop(i)
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Delete requests in the running batch
        for req in self.running_batch.reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
                return
1945

1946
1947
1948
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Hook for pausing the engine; this scheduler does not implement it."""
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
1949
1950
1951
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed. An ``AssertionError`` is raised if
        ``flush_cache()`` refuses because requests are pending. On failure the
        worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
1958

1959
1960
1961
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and relays its ``(success, message)`` pair.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1963
1964

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> "UpdateWeightsFromDistributedReqOutput":
        """Update the online model parameter.

        Delegates to ``tp_worker.update_weights_from_distributed``:
            - On success, the cache is flushed. An ``AssertionError`` is raised
              if ``flush_cache()`` refuses because requests are pending.
            - On failure, the worker's message is logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput(success, message)``. The
            old annotation ``Tuple[bool, str]`` did not match this.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1976

1977
1978
1979
1980
1981
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        On success the cache is flushed only when ``recv_req.flush_cache`` is
        set, with an ``AssertionError`` if the flush is refused. On failure the
        worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
1988

1989
1990
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a named parameter from the TP worker and wrap it in a response."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release memory held through the memory-saver adapter.

        Steps, in order:
          1. validate that the adapter may be used;
          2. snapshot the model's buffers into ``self.stashed_model_static_state``
             (consumed by ``resume_memory_occupation``);
          3. pause the adapter;
          4. flush the cache (its return value is not checked).
        """
        self.memory_saver_adapter.check_validity(
            caller_name="release_memory_occupation"
        )
        target_model = self.tp_worker.worker.model_runner.model
        # The snapshot is taken before pause(), presumably because buffer
        # contents are not readable once the memory is released.
        self.stashed_model_static_state = _export_static_state(target_model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Re-acquire memory and restore the buffer snapshot taken at release.

        Resumes the memory-saver adapter, copies the stashed buffer values back
        into the model, and drops the stash. Raises AttributeError if no stash
        exists (i.e. no prior ``release_memory_occupation``).
        """
        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
        self.memory_saver_adapter.resume()

        target_model = self.tp_worker.worker.model_runner.model
        _import_static_state(target_model, self.stashed_model_static_state)
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

    def slow_down(self, recv_req: SlowDownReqInput):
        """Set ``self.forward_sleep_time`` from the request.

        Non-positive values are normalized to None (presumably meaning "no
        sleep"); None and positive values are stored unchanged.
        """
        requested = recv_req.forward_sleep_time
        non_positive = requested is not None and requested <= 0
        self.forward_sleep_time = None if non_positive else requested
        return SlowDownReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Route a profiling request to ``start_profile`` or ``stop_profile``.

        Any request type other than START_PROFILE is treated as a stop.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        return self.start_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_id,
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional["ProfileReqOutput"]:
        """Begin a profiling session.

        Args:
            output_dir: Directory for trace files. Defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` env var, else ``/tmp``.
            num_steps: When truthy, the session targets forward step
                ``self.forward_ct + num_steps`` and the reply is deferred (sent
                later by ``stop_profile``); when falsy, a reply is returned now.
            activities: Any of ``"CPU"``, ``"GPU"`` (torch profiler),
                ``"MEM"`` (CUDA memory history) and ``"CUDA_PROFILER"``
                (``cudaProfilerStart``). Unknown names are ignored. Defaults to
                ``["CPU", "GPU"]``.
            with_stack: Passed to ``torch.profiler.profile``; defaults to True.
            record_shapes: Passed to ``torch.profiler.profile``; defaults to
                False.
            profile_id: Prefix for output file names. If None, the current
                ``time.time()`` is used so file names are always well-formed.

        Returns:
            A failure ``ProfileReqOutput`` if a session is already active, a
            success ``ProfileReqOutput`` when ``num_steps`` is falsy, and
            ``None`` otherwise.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # stop_profile builds output file names from this id; a None id
            # would otherwise break trace export at stop time.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")

    def stop_profile(self) -> None:
        """Stop the active profiling session and write its outputs.

        No-op when no session is active (``self.profiler_activities is None``).
        Otherwise:
          - stops the torch profiler (if any) and exports a Chrome trace;
          - dumps the CUDA memory snapshot if "MEM" was requested;
          - stops the CUDA profiler if "CUDA_PROFILER" was requested;
          - resets the profiler state;
          - if the session was started with ``num_steps``, sends the deferred
            success reply to the tokenizer.
        The output directory is created if it does not exist.
        """
        if self.profiler_activities is None:
            return

        logger.info("Stop profiling...")
        output_dir = self.torch_profiler_output_dir
        # f-string formatting tolerates a None id, whereas `+` concatenation
        # would raise TypeError here and leave the profiler stuck "in progress".
        file_stem = f"{self.profiler_id}-TP-{self.tp_rank}"

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            os.makedirs(output_dir, exist_ok=True)
            self.torch_profiler.export_chrome_trace(
                os.path.join(output_dir, file_stem + ".trace.json.gz")
            )

        if "MEM" in self.profiler_activities:
            os.makedirs(output_dir, exist_ok=True)
            memory_profile_path = os.path.join(
                output_dir, file_stem + "-memory" + ".pickle"
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )

    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is not one of START_RECORD, STOP_RECORD
                or DUMP_RECORD.
        """
        command_table = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for command, recorder_method in command_table:
            if recv_req == command:
                getattr(expert_distribution_recorder, recorder_method)()
                return ExpertDistributionReqOutput()
        raise ValueError("Unrecognized ExpertDistributionReq value")

    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        Returns an OpenSessionReqOutput whose second field is True when the
        session was created and False when the id is already registered or is
        None (a warning is logged in both failure cases).
        """
        sid = recv_req.session_id
        if sid in self.sessions:
            logger.warning(f"session id {sid} already exist, cannot open.")
            opened = False
        elif sid is None:
            logger.warning("session id is None, cannot open.")
            opened = False
        else:
            self.sessions[sid] = Session(recv_req.capacity_of_str_len, sid)
            opened = True
        return OpenSessionReqOutput(sid, opened)

    def close_session(self, recv_req: CloseSessionReqInput):
        """Forget the session named by ``recv_req.session_id``.

        Unknown ids are not an error: a warning is logged and nothing changes.
        """
        sid = recv_req.session_id
        if sid in self.sessions:
            del self.sessions[sid]
            return
        logger.warning(f"session id {sid} does not exist, cannot delete.")

    def get_print_prefix(self):
        """Build a log prefix like " DP0 TP1 PP2" from this scheduler's ranks.

        Each component is included only when relevant: DP when ``dp_rank`` is
        set, TP when tensor-parallel size > 1, PP when pipeline size > 1.
        """
        parts = []
        if self.dp_rank is not None:
            parts.append(f" DP{self.dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)


def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generation request.

    A request qualifies when it has a string ``rid`` starting with
    "HEALTH_CHECK". Objects without ``rid``, or whose ``rid`` is not a string
    (e.g. None or a list of ids), are not health checks; previously such
    values raised AttributeError on ``.startswith``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2192
2193
2194
2195
2196
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Configures the process (parent-death handling, process title, fault
    handler, logging, optional CPU affinity), builds a Scheduler, reports
    readiness and token limits through ``pipe_writer``, then runs the event
    loop selected by the disaggregation mode, pipeline-parallel size and
    overlap setting. Exceptions raised while constructing the scheduler or
    running its loop are logged and SIGQUIT is sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolved before building the prefix so the process title and log prefix
    # show the same DP rank the Scheduler is constructed with.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)