scheduler.py 78.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
45
46
47
48
49
50
51
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
52
    TransferBackend,
Byron Hsu's avatar
Byron Hsu committed
53
)
54
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
55
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
56
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
57
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
58
59
from sglang.srt.managers.io_struct import (
    AbortReq,
60
    CloseSessionReqInput,
61
    ExpertDistributionReq,
62
    ExpertDistributionReqOutput,
63
64
    FlushCacheReqInput,
    FlushCacheReqOutput,
65
66
    GetInternalStateReq,
    GetInternalStateReqOutput,
67
68
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
69
    HealthCheckOutput,
70
71
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
72
73
    OpenSessionReqInput,
    OpenSessionReqOutput,
74
    ProfileReq,
75
76
    ProfileReqOutput,
    ProfileReqType,
77
78
79
80
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
81
82
    RpcReqInput,
    RpcReqOutput,
83
84
    SetInternalStateReq,
    SetInternalStateReqOutput,
85
86
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
87
88
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
89
90
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
91
92
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
93
94
95
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
96
    MultimodalInputs,
97
98
    Req,
    ScheduleBatch,
99
    global_server_args_dict,
100
)
101
102
103
104
105
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
106
107
108
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
109
from sglang.srt.managers.session_controller import Session
110
from sglang.srt.managers.tp_worker import TpModelWorker
111
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
112
from sglang.srt.managers.utils import validate_input_length
113
from sglang.srt.mem_cache.chunk_cache import ChunkCache
114
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
115
from sglang.srt.mem_cache.radix_cache import RadixCache
116
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Mick's avatar
Mick committed
117
from sglang.srt.model_executor.forward_batch_info import ForwardMode
118
from sglang.srt.reasoning_parser import ReasoningParser
119
from sglang.srt.server_args import PortArgs, ServerArgs
120
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
121
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
122
from sglang.srt.utils import (
123
    DynamicGradMode,
124
125
    broadcast_pyobj,
    configure_logger,
126
    crash_on_warnings,
127
    get_bool_env_var,
128
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
129
    kill_itself_when_parent_died,
130
    pyspy_dump_schedulers,
131
    set_gpu_proc_affinity,
132
133
134
    set_random_seed,
    suppress_other_loggers,
)
135
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
136

137
138
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Debug switches, read once from the environment when this module is imported.
# SGLANG_TEST_RETRACT: test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: boolean flag; its consumers are outside this section.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
144

145

146
147
148
149
@dataclass
class GenerationBatchResult:
    """Result bundle for one generation batch.

    Constructor arguments are positional in the declaration order below.
    """

    # Output object of the logits processor for this batch.
    logits_output: LogitsProcessorOutput
    # Next token ids (annotated as a list of ints).
    next_token_ids: List[int]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request logprob start lengths within the extend segment.
    extend_logprob_start_len_per_req: List[int]
    # Identifier of the batch these results belong to.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Result bundle for one embedding batch: the embeddings tensor and batch id."""

    # Embeddings produced for the batch.
    embeddings: torch.Tensor
    # Identifier of the batch these results belong to.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
161
162
163
164
165
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
166
167
168
169
170
171
172
173
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, model workers, caches and schedule state.

        Args:
            server_args: Server configuration; most scheduler knobs are read from it.
            port_args: IPC endpoint names and the NCCL port used by the workers.
            gpu_id: GPU id forwarded to the TP worker (and the draft worker, if any).
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the workers. ``self.dp_rank`` is
                recomputed separately by ``compute_dp_attention_world_info`` and is
                not necessarily equal to this argument.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only rank 0 of each attention-TP
        # group owns real sockets; other ranks get no-op senders.
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (e.g. -1) disables it. ``None``
        # is treated as disabled too instead of failing the ``<= 0`` comparison;
        # init_memory_pool_and_cache likewise allows this setting to be None.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. ``parent_process`` is resolved before the thread is
        # started so the thread target (a bound method defined outside this
        # section) never sees a partially initialised attribute if it reads it.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

433
434
    def init_tokenizer(self):
        """Create ``self.model_config`` and load the tokenizer / processor.

        Always assigns ``model_config``, ``is_generation``, ``tokenizer`` and
        ``processor``:

        * ``server_args.skip_tokenizer_init``: ``tokenizer`` and ``processor`` are
          both ``None``.
        * multimodal model: a processor is loaded and its ``tokenizer`` is reused.
        * text-only model: a tokenizer is loaded and ``processor`` is ``None``.
        """
        server_args = self.server_args

        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            # Text-only models have no processor. Assign None so the attribute
            # exists on every path, matching the skip_tokenizer_init branch.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix tree cache.

        Cache selection:

        * ``ChunkCache`` when ``chunked_prefill_size`` is not ``None`` and the
          radix cache is disabled;
        * ``HiRadixCache`` when hierarchical caching is enabled;
        * ``RadixCache`` otherwise (``disable`` mirrors ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise ``num_draft_tokens + eagle_topk * num_steps``.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        req_pool = self.req_to_token_pool
        kv_allocator = self.token_to_kv_pool_allocator

        use_chunk_cache = (
            args.chunked_prefill_size is not None and args.disable_radix_cache
        )
        if use_chunk_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=req_pool,
                token_to_kv_pool_allocator=kv_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
            )

        if self.spec_algorithm.is_none():
            buf_multiplier = 1
        else:
            buf_multiplier = args.speculative_num_draft_tokens + (
                args.speculative_eagle_topk * args.speculative_num_steps
            )
        self.decode_mem_cache_buf_multiplier = buf_multiplier

    def init_metrics(self):
        """Reset throughput, length and speculative-decoding counters.

        When ``self.enable_metrics`` is true, also creates a
        ``SchedulerMetricsCollector`` labelled with the served model name and
        the ``"unified"`` engine type.
        """
        # Longest prefill of any single request seen so far.
        self._largest_prefill_len: int = 0
        # Longest prefill + generation (total context) of any single request.
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # batch size -> list of step times
        self.step_time_dict = defaultdict(list)

        # Speculative-decoding acceptance bookkeeping.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        self.stats = SchedulerStats()
        if not self.enable_metrics:
            return

        self.metrics_collector = SchedulerMetricsCollector(
            labels={
                "model_name": self.server_args.served_model_name,
                "engine_type": "unified",
            },
        )

Byron Hsu's avatar
Byron Hsu committed
537
    def init_disaggregation(self):
        """Set up the state used for prefill/decode disaggregation.

        ``self.transfer_backend`` is always resolved from server args.

        * DECODE mode: creates the KV transfer queue, the pre-allocation queue
          and the ``num_tokens_pre_allocated`` counter.
        * PREFILL mode: creates the bootstrap queue and the list of requests
          whose KV cache is still being sent.
        * Any other mode: nothing beyond the transfer backend.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = self._alloc_disagg_metadata_buffers(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = self._alloc_disagg_metadata_buffers(self.max_running_requests * 2)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    @staticmethod
    def _alloc_disagg_metadata_buffers(buffer_size: int, metadata_width: int = 16):
        """Allocate the CPU metadata buffers shared by the disaggregation queues.

        Args:
            buffer_size: Number of rows (metadata slots) to allocate.
            metadata_width: Number of int32 columns per row. The default of 16
                gives 64 bytes per row; the existing code notes that rows must
                be padded to this size to work with RDMA.

        Returns:
            ``(allocator, aux_dtype, metadata_buffers)``. ``allocator`` is a
            ``ReqToMetadataIdxAllocator`` built with ``buffer_size``.
            ``aux_dtype`` is ``torch.int32``. ``metadata_buffers`` is a list
            holding one zero-filled ``(buffer_size, metadata_width)`` CPU tensor.
        """
        allocator = ReqToMetadataIdxAllocator(buffer_size)
        aux_dtype = torch.int32
        output_id_buffer = torch.zeros(
            (buffer_size, metadata_width), dtype=aux_dtype, device="cpu"
        )
        return allocator, aux_dtype, [output_id_buffer]

615
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop: receive, schedule, run and process, serially.

        Each iteration processes the selected batch's result before moving on.
        When nothing is scheduled, the scheduler runs ``check_memory`` and resets
        ``new_token_ratio`` to its initial value.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                batch_result = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_result)

            self.last_batch = next_batch

635
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the newly scheduled batch (if any) first. It then
        processes the result of the previous iteration's batch, which is held in
        ``self.result_queue``.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            current = self.get_next_batch_to_run()
            self.cur_batch = current

            if current:
                launched = self.run_batch(current)
                self.result_queue.append((current.copy(), launched))

                if self.last_batch is None:
                    # No previous batch yet: push a DUMMY_FIRST batch through
                    # result processing to start the overlap pipeline.
                    # It is now used for triggering the sampling_info_done event.
                    bootstrap_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(bootstrap_batch, None)

            if self.last_batch:
                # Consume the result of the batch launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if current else None
                )
                self.process_batch_result(prev_batch, prev_result)
            elif current is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = current

675
676
    def recv_requests(self) -> List[Req]:
        """Receive requests on attention-TP rank 0 and broadcast them to other ranks.

        Rank 0 of the attention-TP group drains the tokenizer socket, then the RPC
        socket, without blocking. Other ranks start with ``None`` and obtain the
        requests through the broadcasts below.

        With DP attention, generate/embedding ("work") requests are broadcast within
        the attention-TP group. All other ("control") requests are broadcast across
        the full TP group, and the result is work requests followed by control
        requests. Without DP attention, everything is broadcast across the TP group.
        """
        work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)

        if self.attn_tp_rank == 0:
            recv_reqs = []
            for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                # Drain everything currently queued on this socket.
                while True:
                    try:
                        incoming = sock.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(incoming)
        else:
            recv_reqs = None

        if self.server_args.enable_dp_attention:
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    if isinstance(req, work_types):
                        work_reqs.append(req)
                    else:
                        control_reqs.append(req)
            else:
                work_reqs = control_reqs = None

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=self.dp_rank * self.attn_tp_size,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
733
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and forward any immediate reply.

        A health-check generate request is skipped, and counted in
        ``return_health_check_ct``, when a chunked prefill is pending or the
        running batch is non-empty. Replies of type ``RpcReqOutput`` are sent
        back on the RPC socket (only if this rank has one). Every other
        non-None reply goes to the tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                has_work = (
                    self.chunked_req is not None or not self.running_batch.is_empty()
                )
                if has_work:
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
749
750
751
752
753

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Convert a tokenized generate request into a ``Req`` and enqueue it.

        Flow:
          * Without a known session, a fresh ``Req`` is built. If the request
            names a session id that does not exist, it is aborted.
          * With a known session, the session creates the ``Req``. An aborted
            result is enqueued as-is.
          * If only ``input_embeds`` were given, placeholder input ids
            (all ``1``) are generated to match their length.
          * A custom logit processor is dropped, with a warning, unless the
            server enables it.
          * Multimodal inputs are expanded via ``pad_input_ids_func``. The
            request is aborted if the expanded prompt reaches
            ``max_req_input_len``.
          * On a ``validate_input_length`` error, the prompt is replaced with
            ``[0]`` and ``max_new_tokens`` is set to 0.
          * ``logprob_start_len`` defaults to the last input token. Out-of-range
            values abort the request.
          * ``max_new_tokens`` is capped so that input plus output fits in
            ``max_req_len``.
          * Requests with a structured-output constraint either reuse a cached
            grammar or wait in ``grammar_queue`` for a future one.

        Aborted requests are still enqueued, presumably so the abort is
        reported through the normal output path.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Each literal ends with a space so the sentences do not run
                # together. The flag name matches the
                # `enable_custom_logit_processor` server argument.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not registered.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so that prompt + output stays below max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request.
        # Every constraint kind is tested with `is not None`, so any request
        # that carries a constraint always gets a key. Previously a falsy but
        # non-None structural_tag passed the outer check without producing a
        # key, which raised UnboundLocalError.
        sampling_params = req.sampling_params
        key = None
        if sampling_params.json_schema is not None:
            key = ("json", sampling_params.json_schema)
        elif sampling_params.regex is not None:
            key = ("regex", sampling_params.regex)
        elif sampling_params.ebnf is not None:
            key = ("ebnf", sampling_params.ebnf)
        elif sampling_params.structural_tag is not None:
            key = ("structural_tag", sampling_params.structural_tag)

        if key is not None:
            assert self.grammar_backend is not None
            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Grammar is not compiled yet: park the request until it is ready.
                req.grammar = self.grammar_backend.get_future_value(key)
                self.grammar_queue.append(req)
                return

        self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Record the enqueue time of ``req`` and place it in the right queue.

        The queue depends on ``disaggregation_mode``:
        PREFILL -> ``disagg_prefill_bootstrap_queue``,
        DECODE  -> ``disagg_decode_prealloc_queue``,
        anything else -> ``waiting_queue``.
        """
        req.queue_time_start = time.time()

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests at once, e.g. requests retracted from decode.

        In DECODE disaggregation mode they go to ``disagg_decode_prealloc_queue``.
        Otherwise they go to ``waiting_queue``. ``is_retracted`` is accepted but
        not used here. Unlike ``_add_request_to_queue``, ``queue_time_start``
        is left untouched.
        """
        destination = (
            self.disagg_decode_prealloc_queue
            if self.disaggregation_mode == DisaggregationMode.DECODE
            else self.waiting_queue
        )
        destination.extend(reqs)
919
920
921

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Convert a tokenized embedding request into a ``Req`` and enqueue it.

        Multimodal inputs are expanded via ``pad_input_ids_func``. The request
        is aborted with BAD_REQUEST if the expanded prompt reaches
        ``max_req_input_len``. A ``validate_input_length`` error also enqueues
        the request without further setup. Otherwise ``logprob_start_len`` is
        set to the last input token.

        Every exit goes through ``_add_request_to_queue``, so the request is
        routed according to ``disaggregation_mode`` like generate requests.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper (it also stamps queue_time_start) so
                # this abort path honors the disaggregation mode, like the
                # other exits below and in handle_generate_request.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
970

971
972
973
974
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a summary of a newly formed prefill batch and publish metrics.

        Side effects:
          * Refreshes ``last_input_throughput``: prefill tokens counted since
            the previous call, divided by the elapsed wall time.
          * Resets ``num_prefill_tokens``.
          * Tracks the largest prefill seen in ``_largest_prefill_len``.
          * When ``enable_metrics`` is set, fills ``self.stats`` and hands it
            to ``metrics_collector``.

        Args:
            adder: The PrefillAdder that built the batch. Its
                ``log_input_tokens`` / ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into the new prefill batch.
            running_bs: Number of requests in the running batch at admission time.
        """
        gap_latency = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        num_new_seq = len(can_run_list)
        msg = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            msg += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            msg += f"#queue-req: {len(self.waiting_queue)}, "
            msg += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
        else:
            msg += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(msg)

        if self.enable_metrics:
            total_prefill_tokens = adder.log_input_tokens + adder.log_hit_tokens
            # Avoid ZeroDivisionError when the batch reports no new and no cached tokens.
            cache_hit_rate = (
                adder.log_hit_tokens / total_prefill_tokens
                if total_prefill_tokens > 0
                else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            # queue_time_end is stamped on admitted requests when metrics are enabled.
            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log a one-line summary of decode progress and publish metrics.

        Recomputes ``last_gen_throughput`` from the tokens generated since the
        previous call and resets ``num_generated_tokens``. With
        ``RECORD_STEP_TIME``, appends ``gap_latency / decode_log_interval`` to
        ``step_time_dict`` under the current running-request count.

        With speculative decoding, reports the mean accepted length since the
        previous call. That is 0 if no speculative forward pass was counted,
        instead of raising ZeroDivisionError. The interval counters are then
        folded into the cumulative totals and reset.
        """
        gap_latency = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard against an interval with no counted speculative forward passes.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
1079
1080
    def check_memory(self):
        """Sanity-check KV-cache and request-pool accounting while idle.

        Emits a warning, and raises ValueError when ``crash_on_warnings()`` is
        true, if either check fails:
          * The allocator's free size plus the tree cache's evictable size does
            not equal ``max_total_num_tokens``. With hierarchical caching, the
            tree cache's protected size is subtracted from the target first.
          * Not every ``req_to_token_pool`` slot is free.

        When metrics are enabled on attention-TP rank 0 and more than 30
        seconds have passed since the collector's last log, idle usage stats
        are published as well.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the header from the details, as in
                # the KV-pool message above.
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

1129
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        If the previous step ran an extend (prefill) batch, its requests are
        first folded into ``self.running_batch``:
          * The in-flight chunked request, if any, is excluded, cached as
            unfinished, and has its request-pool slot freed.
          * Finished requests are filtered out.
          * ``batch_is_full`` is cleared whenever something was removed.

        A new prefill batch is then preferred. Otherwise the running batch is
        updated for decode and used if still non-empty.

        Returns:
            The batch to run, or None. With DP attention or SP layernorm
            enabled, the choice is passed through ``prepare_dp_attn_batch``.
        """
        previous = self.last_batch
        if previous and previous.forward_mode.is_extend():
            chunked = self.chunked_req
            if chunked:
                # Pull the chunked request out so only completed prefills get merged.
                previous.filter_batch(chunked_req_to_exclude=chunked)
                self.tree_cache.cache_unfinished_req(chunked)
                # The chunked request keeps its rid; its req_pool_idx is released here.
                self.req_to_token_pool.free(chunked.req_pool_idx)
                self.running_batch.batch_is_full = False

            size_before_filter = previous.batch_size()
            previous.filter_batch()
            if previous.batch_size() < size_before_filter:
                self.running_batch.batch_is_full = False

            if not previous.is_empty():
                if not self.running_batch.is_empty():
                    self.running_batch.merge_batch(previous)
                else:
                    self.running_batch = previous

        chosen = self.get_new_batch_prefill()
        if chosen is None and not self.running_batch.is_empty():
            # No prefill available: continue decoding the running batch.
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                chosen = self.running_batch

        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            chosen, _ = self.prepare_dp_attn_batch(chosen)

        return chosen
1172

Lianmin Zheng's avatar
Lianmin Zheng committed
1173
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Returns None, and may set ``running_batch.batch_is_full``, when no
        prefill is possible:
          * the running batch is full or the queue is empty, with no chunked
            request pending;
          * the running-request limit is reached;
          * no waiting request could be admitted.

        Otherwise it admits requests through a ``PrefillAdder``, in the
        priority order computed by ``self.policy``. Admission stops at the
        first request that would exceed the LoRA-per-batch or running-request
        limits, or that the adder rejects. Admitted requests are removed from
        the waiting queue.

        Returns:
            A ``ScheduleBatch`` prepared for extend. In mixed-chunk mode (when
            neither side requests logprobs), the current decode requests are
            mixed into it and ``running_batch`` is replaced by an empty batch.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            queue_time_end = time.time()
            for req in can_run_list:
                req.queue_time_end = queue_time_end

        # Build the membership set once. Rebuilding it per element made this
        # filter O(len(waiting_queue) * len(can_run_list)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1316
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next decode step.

        Finished requests are filtered out first. An empty result is returned
        immediately with ``batch_is_full`` cleared.

        When ``check_decode_mem`` fails, or ``TEST_RETRACT`` is set and the
        batch holds more than 10 requests, ``retract_decode`` is called. Its
        requests are put back on the queue and its ratio replaces
        ``new_token_ratio``. Otherwise ``new_token_ratio`` decays by
        ``new_token_ratio_decay``, never going below ``min_new_token_ratio``.
        ``batch_is_full`` is cleared whenever the batch shrank.

        Returns:
            The same ``batch`` object, with ``prepare_for_decode`` applied.
        """
        size_at_entry = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            old_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)

        if batch.batch_size() < size_at_entry:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1352

1353
1354
1355
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Execute one forward pass for ``batch`` and wrap its outputs.

        Increments ``forward_ct`` and stops an active profiler once the
        requested forward-step target is reached. Generation models run through
        the TP worker, or through the draft worker when a speculative algorithm
        is configured; embedding/reward models run ``forward_batch_embedding``.
        """
        self.forward_ct += 1

        # Stop profiling once the requested number of forward steps has run.
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens
        batch.output_ids = next_token_ids

        # Snapshot these per-request values now: the overlap scheduler may modify
        # them before output processing, which needs the values from this step.
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]
        else:
            extend_input_len_per_req = None
            extend_logprob_start_len_per_req = None

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1414

1415
1416
1417
1418
1419
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Route ``result`` to the handler for ``batch.forward_mode``.

        - Decode results go to ``process_batch_result_decode`` and extend results
          to ``process_batch_result_prefill``.
        - Idle batches under overlap scheduling resolve the result by ``bid``. If
          a next-batch sampling info exists, its regex vocab mask is updated, the
          current stream is synchronized, and ``sampling_info_done`` is set.
        - Dummy-first batches perform that same mask/sync/signal sequence
          unconditionally.

        Afterwards, a pending health-check signal is emitted if one is owed.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                next_info = batch.next_batch_sampling_info
                if next_info:
                    next_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
                    next_info.sampling_info_done.set()
        elif mode.is_dummy_first():
            next_info = batch.next_batch_sampling_info
            next_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            next_info.sampling_info_done.set()

        if self.return_health_check_ct:
            # Emitting the health-check signal here keeps it from being starved
            # by a long-context prefill. Caveat: this path does not check the
            # status of the detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1443
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across DP attention workers.

        Thin wrapper that forwards this scheduler's DP / attention-TP settings to
        ``prepare_dp_attn_batch_raw`` and returns its ``(batch, any_extend)``.
        """
        server_args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=server_args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=server_args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Exchange per-worker batch metadata across all DP attention workers.

        Every rank contributes a 4-element int64 vector
        ``[num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend]``
        through ``all_gather_into_tensor`` on ``tp_cpu_group``. The gathered
        tensor is viewed as ``(dp_size, attn_tp_size, 4)``, and only attention-TP
        index 0 of each DP slot is read. If this worker has no batch but some
        worker reports tokens, an idle batch is built with ``get_idle_batch``.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` may still be None when
            no worker has tokens. ``any_extend`` is True if any DP worker reported
            an extend batch.
        """
        has_batch = local_batch is not None

        # Token counts this worker contributes to the gather.
        if not has_batch:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # With EAGLE each decode request counts as
                # speculative_num_draft_tokens tokens.
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # Every request contributes at least one token for sampling.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        # Only decode/idle batches (or no batch at all) are CUDA-graph eligible.
        can_cuda_graph = (
            1 if (not has_batch or local_batch.forward_mode.is_decode_or_idle()) else 0
        )
        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if not has_batch or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        first_tp_rank = global_info[:, 0, :]
        global_num_tokens = first_tp_rank[:, 0].tolist()
        # CUDA graph is usable only if every DP worker allows it.
        can_cuda_graph = min(first_tp_rank[:, 1].tolist())
        global_num_tokens_for_logprob = first_tp_rank[:, 2].tolist()
        any_extend = any(first_tp_rank[:, 3].tolist())

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any_extend
1535
1536
1537
1538
1539

    def get_idle_batch(self):
        """Create an empty ScheduleBatch and switch it into idle mode.

        ``prepare_dp_attn_batch`` passes this as the factory it calls when this
        worker has no batch but another DP worker has tokens.
        """
        empty_batch = ScheduleBatch.init_new(
            [],  # no requests
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each queued request's ``grammar`` attribute holds a future. Requests are
        scanned in order, and the scan stops at the first future that does not
        resolve within 0.05 s.

        When the (attention-)TP group has more than one rank, the ranks all-reduce
        the MAX ready count. Each rank then blocks on any futures it has not yet
        resolved, so every rank moves the same prefix of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
            except futures.TimeoutError:  # public alias; avoid the private _base module
                break
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. While a batch is active
        (``cur_batch`` is not None) and ``forward_ct`` has not advanced for more
        than ``watchdog_timeout`` seconds, it stops polling. It then logs batch
        and memory-pool diagnostics, dumps the scheduler stacks, flushes
        stdout/stderr, waits 5 s, and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division (`// 2`) turns any timeout below 2s
            # into sleep(0), making this loop a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"{self.cur_batch.batch_size()=}, "
            f"{self.cur_batch.reqs=}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        # Wait for some time so that the parent process can print the error.
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1612
1613
1614
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether the flush happened."""
        return FlushCacheReqOutput(success=self.flush_cache())
1615

1616
    def flush_cache(self):
        """Flush the memory pool and cache.

        Only proceeds when the scheduler is idle, meaning the waiting queue and
        the running batch are both empty. On success it:
        - drops ``cur_batch`` / ``last_batch``;
        - resets the tree cache and the grammar backend (if any);
        - clears the request/token pools, including the draft worker's pools when
          speculative decoding is enabled;
        - zeroes the generation and speculative-acceptance counters;
        - calls ``torch.cuda.empty_cache()``.

        Returns:
            True if the cache was flushed, False if pending requests prevented it.
        """
        if len(self.waiting_queue) != 0 or not self.running_batch.is_empty():
            # Use the module logger (not the root `logging` module) like the
            # rest of the scheduler.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a copy of the global server args plus runtime statistics.

        Always includes ``last_gen_throughput``. Adds ``avg_spec_accept_length``
        when speculative decoding is on and has recorded acceptances, and adds
        ``step_time_dict`` when RECORD_STEP_TIME is enabled.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict
        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Only the speculative acceptance thresholds may be changed. If any
        requested key is outside the whitelist, nothing is applied. On success
        the average speculative acceptance length (if any) is logged, and the
        cumulative acceptance statistics are reset.

        Returns:
            SetInternalStateReqOutput. Its ``updated`` flag reports whether the
            update was actually applied; ``server_args`` is the current global
            server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k in server_args_dict:
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            global_server_args_dict.update(server_args_dict)
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        # Previously this always reported updated=True, even when the request
        # was rejected and nothing changed.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler method named ``recv_req.method`` with ``recv_req.parameters``.

        Any exception from the lookup or the call is caught and logged, and its
        message is returned in the output. A ``torch.distributed`` barrier runs
        after the call whether or not it succeeded.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        error = None  # renamed from `exec`, which shadowed the builtin
        try:
            # NOTE(review): any attribute of the scheduler can be invoked by
            # name; this assumes RPC requests only come from trusted callers.
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if not error else str(error))

    def save_remote_model(self, params):
        """Ask the model runner to save the model to ``params["url"]``."""
        target_url = params["url"]
        runner = self.tp_worker.worker.model_runner
        runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save the model as shards using ``path``, ``pattern`` and ``max_size`` from ``params``."""
        runner = self.tp_worker.worker.model_runner
        runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

1727
1728
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests in the waiting queue are removed from it. Matching
        unfinished requests in the running batch are flagged with
        ``to_abort = True``; presumably the flag is acted on during later output
        processing (confirm against the consumer).

        The original version stopped after the first match (``break`` /
        ``return``), which contradicted its own collect-then-delete-in-reverse
        logic and left other prefix-matching requests running.
        """
        # Delete requests in the waiting queue. Pop from the highest index down
        # so earlier indices stay valid while deleting.
        to_del = [
            i
            for i, req in enumerate(self.waiting_queue)
            if req.rid.startswith(recv_req.rid)
        ]
        for i in reversed(to_del):
            req = self.waiting_queue.pop(i)
            logger.debug(f"Abort queued request. {req.rid=}")

        # Flag matching requests in the running batch.
        for req in self.running_batch.reqs:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
1747

1748
1749
1750
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine. Not implemented in this class; always raises."""
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
1751
1752
1753
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the load to the TP worker. On success the cache is flushed and
        the flush is asserted to have happened; on failure the worker's message
        is logged.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
1760

1761
1762
1763
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group via the TP worker."""
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1765
1766

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameters through the TP worker's distributed path.

        On success the cache is flushed and the flush is asserted to have
        happened; on failure the worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput carrying the worker's
            ``(success, message)``. The previous ``Tuple[bool, str]`` annotation
            did not match what is actually returned.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1778

1779
1780
1781
1782
1783
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from in-memory tensors.

        On success the cache is flushed only when ``recv_req.flush_cache`` is
        set, and that flush is asserted to have happened. On failure the
        worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
1790

1791
1792
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the parameter named in ``recv_req`` from the TP worker."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
1794

1795
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model's buffers, pause the memory saver adapter, and flush the cache.

        The stashed buffers are written back by ``resume_memory_occupation``.
        """
        self.memory_saver_adapter.check_validity(
            caller_name="release_memory_occupation"
        )
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
1805

1806
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the buffers stashed on release."""
        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        # The stash is single-use; drop it once restored.
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Start profiling for START_PROFILE requests; stop it for anything else."""
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()
        return self.start_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_id,
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional["ProfileReqOutput"]:
        """Begin a profiling session.

        Args:
            output_dir: Directory for trace files. Defaults to
                ``$SGLANG_TORCH_PROFILER_DIR`` or ``/tmp``.
            num_steps: If set, profiling is stopped by ``run_batch`` after this
                many more forward steps, and the reply is sent from there.
            activities: Any of "CPU", "GPU", "MEM", "CUDA_PROFILER". Defaults to
                ["CPU", "GPU"].
            with_stack: Passed to ``torch.profiler.profile`` (default True).
            record_shapes: Passed to ``torch.profiler.profile`` (default False).
            profile_id: Prefix for trace filenames. Defaults to the current
                timestamp, because ``stop_profile`` needs it to build paths.

        Returns:
            A failure ProfileReqOutput if a session is already active, a success
            ProfileReqOutput when ``num_steps`` is not set, and otherwise None
            because the reply is deferred until the target step is reached.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # A None id would make trace-path construction fail in stop_profile.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")
1885
1886

    def stop_profile(self) -> None:
        """Stop the active profiling session and write its traces.

        Does nothing if no session is active. Otherwise it:
        - stops the torch profiler and exports its Chrome trace;
        - dumps a CUDA memory snapshot when "MEM" was requested;
        - stops the CUDA profiler when "CUDA_PROFILER" was requested;
        - resets the profiling state.

        If the session was started with ``num_steps``, the success reply is sent
        to the tokenizer from here, and the step target is cleared.
        """
        if self.profiler_activities is None:
            return

        logger.info("Stop profiling...")
        # Format with an f-string rather than `+` so a missing profiler_id cannot
        # raise TypeError halfway through teardown and leave the session stuck
        # in the "already in progress" state.
        trace_prefix = f"{self.profiler_id}-TP-{self.tp_rank}"
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    trace_prefix + ".trace.json.gz",
                )
            )

        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                trace_prefix + "-memory.pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
            # Clear the target so run_batch stops re-invoking stop_profile on
            # every subsequent forward step.
            self.profiler_target_forward_ct = None
1923

1924
1925
1926
1927
1928
1929
1930
1931
1932
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump expert-distribution recording per ``recv_req``.

        Raises:
            ValueError: if ``recv_req`` is not a recognized ExpertDistributionReq.
        """
        recorder = expert_distribution_recorder
        if recv_req == ExpertDistributionReq.START_RECORD:
            recorder.start_record()
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            recorder.stop_record()
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            recorder.dump_record()
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
1934

1935
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        Reports failure without changes when the id is already registered or is
        None.
        """
        session_id = recv_req.session_id
        # Reject duplicate or missing ids.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
1949
1950
1951
1952
1953
1954
1955
1956
1957

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session ``recv_req.session_id``; log a warning if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

1958

1959
1960
1961
1962
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has a string ``rid`` starting with "HEALTH_CHECK".

    Objects without a ``rid`` attribute, or whose ``rid`` is not a string
    (e.g. None or a list of rids), are not health checks. Previously these
    raised AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


1977
1978
1979
1980
1981
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    The function runs these steps in order:

    1. Configure the process: call ``kill_itself_when_parent_died``, set the
       process title, enable ``faulthandler``, configure the logger and
       optionally set CPU affinity.
    2. Build a ``Scheduler``.
    3. Send a readiness dict through ``pipe_writer``.
    4. Run the event loop selected by ``scheduler.disaggregation_mode`` and
       ``scheduler.enable_overlap``.

    Any exception raised from scheduler construction onward is logged with its
    traceback, and ``SIGQUIT`` is sent to the parent process.

    Args:
        server_args: Server configuration, passed to ``configure_logger`` and
            ``Scheduler``.
        port_args: Port configuration, passed to ``Scheduler``.
        gpu_id: GPU id, passed to ``Scheduler`` and ``set_gpu_proc_affinity``.
        tp_rank: Tensor-parallel rank, used in the log prefix and passed to
            ``Scheduler``.
        dp_rank: Data-parallel rank, or ``None``. When ``None`` and the
            ``SGLANG_DP_RANK`` environment variable is set, that value is
            passed to ``Scheduler`` instead.
        pipe_writer: Object whose ``send`` receives the readiness dict.
            Presumably a ``multiprocessing`` connection; confirm at the caller.
    """
    # Generate the prefix used in log lines and the process title
    if dp_rank is None:
        prefix = f" TP{tp_rank}"
    else:
        prefix = f" DP{dp_rank} TP{tp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    # NOTE(review): this runs after `prefix` is built, so an env-derived dp_rank
    # reaches Scheduler but not the log prefix / process title. Confirm that
    # this is intended.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )

        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()
        else:
            # Fail loudly instead of returning without running any loop, so
            # the handler below logs the problem and notifies the parent.
            raise ValueError(
                f"Unsupported disaggregation mode: {disaggregation_mode}"
            )

    except Exception:
        exc_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {exc_traceback}")
        parent_process.send_signal(signal.SIGQUIT)