scheduler.py 78 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
45
46
47
48
49
50
51
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
52
    TransferBackend,
Byron Hsu's avatar
Byron Hsu committed
53
)
54
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
55
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
56
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
57
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
58
59
from sglang.srt.managers.io_struct import (
    AbortReq,
60
    CloseSessionReqInput,
61
    ExpertDistributionReq,
62
    ExpertDistributionReqOutput,
63
    FlushCacheReq,
64
65
    GetInternalStateReq,
    GetInternalStateReqOutput,
66
67
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
68
    HealthCheckOutput,
69
70
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
71
72
    OpenSessionReqInput,
    OpenSessionReqOutput,
73
    ProfileReq,
74
75
    ProfileReqOutput,
    ProfileReqType,
76
77
78
79
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
80
81
    RpcReqInput,
    RpcReqOutput,
82
83
    SetInternalStateReq,
    SetInternalStateReqOutput,
84
85
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
86
87
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
88
89
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
90
91
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
92
93
94
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
95
    MultimodalInputs,
96
97
    Req,
    ScheduleBatch,
98
    global_server_args_dict,
99
)
100
101
102
103
104
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
105
106
107
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
108
from sglang.srt.managers.session_controller import Session
109
from sglang.srt.managers.tp_worker import TpModelWorker
110
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
111
from sglang.srt.managers.utils import validate_input_length
112
from sglang.srt.mem_cache.chunk_cache import ChunkCache
113
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
114
from sglang.srt.mem_cache.radix_cache import RadixCache
115
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Mick's avatar
Mick committed
116
from sglang.srt.model_executor.forward_batch_info import ForwardMode
117
from sglang.srt.reasoning_parser import ReasoningParser
118
from sglang.srt.server_args import PortArgs, ServerArgs
119
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
120
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
121
from sglang.srt.utils import (
122
    DynamicGradMode,
123
124
    broadcast_pyobj,
    configure_logger,
125
    crash_on_warnings,
126
    get_bool_env_var,
127
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
128
    kill_itself_when_parent_died,
129
    pyspy_dump_schedulers,
130
    set_gpu_proc_affinity,
131
132
133
    set_random_seed,
    suppress_other_loggers,
)
134
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
135

136
137
# Single module-level recorder instance, created when this module is imported.
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Debug switches, read once from the environment at import time.
#   SGLANG_TEST_RETRACT     -> test retract decode for debugging purposes.
#   SGLANG_RECORD_STEP_TIME -> presumably enables per-step timing records
#                              (see Scheduler.step_time_dict) — confirm.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@dataclass
class GenerationBatchResult:
    """Container for what a generation batch's forward pass hands back.

    The field order below is also the positional constructor order; keep it
    stable.
    """

    # Output of the logits processor for this batch.
    logits_output: LogitsProcessorOutput
    # Next token ids (presumably one per request, in batch order — confirm).
    next_token_ids: List[int]
    # Per-request extend input length.
    extend_input_len_per_req: List[int]
    # Per-request start position for logprob computation during extend.
    extend_logprob_start_len_per_req: List[int]
    # Id of the batch this result belongs to.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Container for what an embedding batch's forward pass hands back."""

    # Embeddings produced for the batch.
    embeddings: torch.Tensor
    # Id of the batch this result belongs to.
    bid: int
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
165
166
167
168
169
170
171
172
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler for one tensor-parallel rank.

        Args:
            server_args: Server configuration; most scheduler settings are read
                from it.
            port_args: IPC endpoint names (scheduler input, tokenizer,
                detokenizer, RPC) and the NCCL port handed to the workers.
            gpu_id: GPU id passed to the TP worker (and draft worker, if any).
            tp_rank: Tensor-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded unchanged to the workers.
                ``self.dp_rank`` is NOT taken from this argument; it is
                recomputed by ``compute_dp_attention_world_info``.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only ranks with attn_tp_rank == 0
        # own sockets; the other ranks receive requests through the broadcasts
        # in recv_requests().
        context = zmq.Context(2)
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            # Stubs so callers can call send_pyobj unconditionally; anything
            # sent through them is discarded.
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache (needs tp_worker and tp_cpu_group above)
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A missing (None) or non-positive size
        # (-1 means disable) turns chunked prefill off. None must be checked
        # first: `None <= 0` raises TypeError.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # The new-token ratio starts at init_new_token_ratio and decays linearly
        # toward min_new_token_ratio over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. The parent process handle is resolved before
        # the thread starts so the attribute already exists whenever the
        # background thread runs.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Init prefill/decode disaggregation (uses tree_cache, pools, groups above)
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()
    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor`` on every code path:

        * ``skip_tokenizer_init``: tokenizer and processor are both ``None``;
        * multimodal model: a processor is loaded and its ``tokenizer`` reused;
        * otherwise: a plain tokenizer is loaded and ``processor`` is ``None``.
        """
        server_args = self.server_args

        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            enable_multimodal=server_args.enable_multimodal,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            # Text-only models have no processor. Define the attribute anyway,
            # matching the skip_tokenizer_init branch, so it always exists.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Cache selection:
        * chunked prefill configured and radix cache disabled -> ``ChunkCache``;
        * hierarchical cache enabled -> ``HiRadixCache``;
        * otherwise -> ``RadixCache`` (passed ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, otherwise ``num_draft_tokens + eagle_topk * num_steps``.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        # Both pools are handed to every cache implementation.
        pool_kwargs = dict(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            self.tree_cache = ChunkCache(**pool_kwargs)
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                **pool_kwargs,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
            )
        else:
            self.tree_cache = RadixCache(
                **pool_kwargs,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            extra_draft_slots = args.speculative_eagle_topk * args.speculative_num_steps
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens + extra_draft_slots
            )

    def init_metrics(self):
        """Reset in-process statistics and, if enabled, create the metrics exporter.

        The counters and ``self.stats`` are always (re)initialised; a
        ``SchedulerMetricsCollector`` is only created when ``self.enable_metrics``
        is true.
        """
        # Largest prefill length of a single request seen so far.
        self._largest_prefill_len: int = 0
        # Largest context length (prefill + generation) of a single request.
        self._largest_prefill_decode_len: int = 0
        # Most recently recorded throughput figures.
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # Batch size -> list of recorded step times.
        self.step_time_dict = defaultdict(list)
        # Speculative-decoding acceptance counters.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        collector_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=collector_labels)
    def init_disaggregation(self):
        """Set up the queues used for prefill/decode disaggregation.

        Always sets ``self.transfer_backend``. Depending on
        ``self.disaggregation_mode`` (set in ``__init__``) it also builds:

        * DECODE: ``self.disagg_decode_transfer_queue`` (decode requests polling
          the KV cache) and ``self.disagg_decode_prealloc_queue`` (decode
          requests pending pre-allocation);
        * PREFILL: ``self.disagg_prefill_pending_queue`` (bootstrap queue) and
          ``self.disagg_prefill_inflight_queue``, an empty list of requests in
          the middle of KV sending.

        Any other mode builds no queues.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # Sized from the request pool; *2 for the headroom.
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = self._create_disagg_metadata_buffers(self.req_to_token_pool.size * 2)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # Sized from max running requests; *2 for the headroom.
            (
                req_to_metadata_buffer_idx_allocator,
                aux_dtype,
                metadata_buffers,
            ) = self._create_disagg_metadata_buffers(self.max_running_requests * 2)

            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    def _create_disagg_metadata_buffers(self, buffer_size: int):
        """Create the metadata index allocator and CPU metadata buffers.

        Shared by the DECODE and PREFILL setup in ``init_disaggregation``.

        Args:
            buffer_size: Number of metadata slots (rows) to allocate.

        Returns:
            ``(req_to_metadata_buffer_idx_allocator, aux_dtype, metadata_buffers)``
            where ``metadata_buffers`` is a list holding one zero-initialised
            ``(buffer_size, 16)`` int32 CPU tensor.
        """
        req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
        aux_dtype = torch.int32
        # A list of metadata buffers with shape (b, metadata_size). The row size
        # (last dim * dtype.itemsize) must be large enough for RDMA (64 bytes),
        # so each row is padded to 16 int32 values = 64 bytes.
        output_id_buffer = torch.zeros(
            (buffer_size, 16), dtype=aux_dtype, device="cpu"
        )
        return req_to_metadata_buffer_idx_allocator, aux_dtype, [output_id_buffer]
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the scheduler forever without CPU/GPU overlap.

        Each iteration ingests new requests, picks the next batch, and either
        runs it and processes its result immediately, or — when there is
        nothing to run — performs idle-time maintenance.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            self.cur_batch = batch = self.get_next_batch_to_run()

            if not batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                self.process_batch_result(batch, self.run_batch(batch))

            self.last_batch = batch
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler forever, overlapping CPU processing with GPU work.

        A batch launched in one iteration has its result processed in the next
        iteration, after the following batch has been launched.
        ``self.result_queue`` holds the ``(batch copy, result)`` pairs waiting
        to be processed.
        """
        self.result_queue = deque()

        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                launched = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched))

                if self.last_batch is None:
                    # The previous iteration launched nothing, so the pipeline is
                    # (re)starting. Process a dummy first batch to start it; it is
                    # now used for triggering the sampling_info_done event.
                    warmup = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(warmup, None)

            if self.last_batch:
                # Process the results of the previous iteration's batch.
                pending_batch, pending_result = self.result_queue.popleft()
                pending_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                self.process_batch_result(pending_batch, pending_result)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
    def recv_requests(self) -> List[Req]:
        """Collect pending requests and hand every rank the ones it needs.

        Only ranks with ``attn_tp_rank == 0`` own the sockets; they drain the
        tokenizer socket, then the RPC socket, without blocking. The list is
        then shared:

        * with DP attention: generate/embedding ("work") requests are broadcast
          within the attention-TP group, all other ("control") requests across
          the whole TP group, and the result is ``work + control``;
        * otherwise: the full list is broadcast across the TP group when
          ``tp_size > 1``.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        else:
            recv_reqs = []
            for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                while True:
                    try:
                        message = sock.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        # Nothing more can be read right now; move on.
                        break
                    recv_reqs.append(message)

        if self.server_args.enable_dp_attention:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    target = work_reqs if isinstance(req, work_types) else control_reqs
                    target.append(req)
            else:
                work_reqs = control_reqs = None

            if self.attn_tp_size != 1:
                # Global rank of this DP group's attention-TP leader.
                group_leader = self.dp_rank * self.attn_tp_size
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_rank,
                    self.attn_tp_cpu_group,
                    src=group_leader,
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs, self.tp_rank, self.tp_cpu_group
                )
            recv_reqs = work_reqs + control_reqs
        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)

        return recv_reqs
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and forward any immediate reply.

        A health-check generation request is counted in
        ``return_health_check_ct`` and dropped when the scheduler already has
        work (a pending chunked request or a non-empty running batch).

        Every other request goes through ``_request_dispatcher``. A non-``None``
        result of type ``RpcReqOutput`` is sent back on ``recv_from_rpc``, if
        that socket exists. Any other result is sent to the tokenizer.
        """
        for recv_req in recv_reqs:
            if is_health_check_generate_req(recv_req):
                has_pending_work = (
                    self.chunked_req is not None
                    or not self.running_batch.is_empty()
                )
                if has_pending_work:
                    self.return_health_check_ct += 1
                    continue

            output = self._request_dispatcher(recv_req)
            if output is None:
                continue

            if not isinstance(output, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(output)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(output)
742
743
744
745
746

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        A fresh ``Req`` is built unless ``recv_req`` names a session id that
        exists in ``self.sessions``; in that case the session creates it.

        Requests that fail a check are still enqueued rather than dropped,
        after this method neutralises their inputs and/or attaches a
        ``FINISH_ABORT`` reason. The checks are:
          * unknown session id;
          * multimodal prompt too long after expansion;
          * input failing ``validate_input_length``;
          * ``logprob_start_len`` out of range.

        Requests with a json/regex/ebnf/structural_tag constraint go to
        ``grammar_queue`` when their grammar is not cached yet. Everything
        else goes through ``_add_request_to_queue``.

        Args:
            recv_req: The tokenized request received from the tokenizer.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                # Adjacent literals are concatenated, so each sentence carries
                # its own trailing space; the flag name matches the
                # ``enable_custom_logit_processor`` server arg.
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not known: abort the request.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Neutralise the request so it finishes immediately.
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so prompt + output stays within max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not
                # None here. A plain ``else`` avoids leaving ``key`` unbound
                # when structural_tag is falsy (e.g. an empty string).
                key = ("structural_tag", req.sampling_params.structural_tag)

            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached yet: wait in grammar_queue until it is compiled.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and route ``req`` to the right queue.

        In PREFILL disaggregation mode the request joins the prefill bootstrap
        queue; in DECODE mode it joins the decode prealloc queue; otherwise it
        is appended to ``waiting_queue``.
        """
        req.queue_time_start = time.time()
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_pending_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests at once to the appropriate queue.

        In DECODE disaggregation mode they go to the decode prealloc queue;
        in any other mode they go to ``waiting_queue``. Unlike
        ``_add_request_to_queue``, ``queue_time_start`` is not touched.
        ``is_retracted`` is currently unused.
        """
        # NOTE(review): PREFILL mode falls through to waiting_queue here,
        # unlike _add_request_to_queue — confirm this is intended.
        destination = (
            self.disagg_decode_prealloc_queue
            if self.disaggregation_mode == DisaggregationMode.DECODE
            else self.waiting_queue
        )
        destination.extend(reqs)
912
913
914

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Multimodal inputs are expanded with ``pad_input_ids_func``. If the
        expanded prompt reaches ``max_req_input_len``, the request is
        neutralised and aborted with a BAD_REQUEST ``FINISH_ABORT``. A request
        failing ``validate_input_length`` is enqueued as-is. Every path uses
        ``_add_request_to_queue`` so timestamps and disaggregation routing
        stay consistent with the generate path.

        Args:
            recv_req: The tokenized embedding request from the tokenizer.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Previously this path set queue_time_start and appended to
                # waiting_queue by hand, bypassing disaggregation routing.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
963

964
965
966
967
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a new prefill batch and update metrics.

        Recomputes ``last_input_throughput`` from the prefill tokens counted
        since the previous call, then resets that counter. Also tracks
        ``_largest_prefill_len`` and, when metrics are enabled, publishes
        usage, cache-hit-rate and average queue-latency stats.

        Args:
            adder: The ``PrefillAdder`` that built the batch; its
                ``log_input_tokens`` / ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into the new prefill batch.
            running_bs: Number of requests in the running batch.
        """
        # Take one timestamp so the measured gap and the new tic agree.
        now = time.time()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now
        # A zero interval (coarse clock) would otherwise raise ZeroDivisionError.
        self.last_input_throughput = (
            self.num_prefill_tokens / gap_latency if gap_latency > 0 else 0.0
        )
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        num_new_seq = len(can_run_list)
        msg = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if self.enable_metrics:
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            # Avoid dividing by zero when the adder counted no tokens.
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            # queue_time_end is stamped by the caller when metrics are enabled.
            total_queue_latency = sum(
                req.queue_time_end - req.queue_time_start for req in can_run_list
            )
            self.stats.avg_request_queue_latency = (
                total_queue_latency / num_new_seq if num_new_seq > 0 else 0.0
            )

            self.metrics_collector.log_stats(self.stats)

    def log_decode_stats(self):
        """Log decode progress and, if enabled, publish decode metrics.

        Computes ``last_gen_throughput`` over the interval since the previous
        call and resets ``num_generated_tokens``. When ``RECORD_STEP_TIME`` is
        set, it records the per-step time keyed by running-batch size.

        In speculative mode it also:
          * reports the average accept length for the interval;
          * folds the interval counters into the cumulative
            ``cum_spec_accept_*`` totals;
          * resets the interval counters.
        """
        # One timestamp so the measured gap and the new tic agree.
        now = time.time()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        # A zero interval (coarse clock) would otherwise raise ZeroDivisionError.
        self.last_gen_throughput = (
            self.num_generated_tokens / gap_latency if gap_latency > 0 else 0.0
        )
        self.num_generated_tokens = 0
        num_running_reqs = len(self.running_batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        # Fields common to both log formats (text kept identical to before).
        head = (
            "Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
            spec_part = ""
        else:
            # Guard against an interval without any speculative forward step.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0.0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            spec_part = f"accept len: {spec_accept_length:.2f}, "

        msg = (
            head
            + spec_part
            + f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            + f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
1066
1067
    def check_memory(self):
        """Idle-time self-check for leaked KV-cache tokens or request slots.

        Expects:
          * allocator-free plus evictable tokens to equal
            ``max_total_num_tokens``, minus the tree cache's protected size
            when hierarchical caching is on;
          * every ``req_to_token_pool`` slot to be free.

        Each mismatch issues a warning and raises ``ValueError`` when
        ``crash_on_warnings()`` is true. When metrics are enabled on attention
        TP rank 0, idle stats are also published at most every 30 seconds.

        Raises:
            ValueError: On a detected leak if ``crash_on_warnings()`` is true.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            # A separator after "detected!" keeps the message readable; the
            # adjacent literals were previously glued together.
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

1116
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch to execute in the next forward pass.

        If the previous batch was an extend (prefill) batch, it is first
        folded into ``running_batch``:
          * a pending chunked request is excluded from it, cached via
            ``tree_cache.cache_unfinished_req``, and has its ``req_pool_idx``
            freed;
          * the batch is then filtered and merged.

        A new prefill batch is preferred. Otherwise the running batch is
        updated for decode and returned, or ``None`` if it ends up empty.
        With DP attention or SP layernorm enabled, the chosen batch is passed
        through ``prepare_dp_attn_batch``.
        """
        # Merge the prefill batch into the running batch
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.chunked_req:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
                self.tree_cache.cache_unfinished_req(self.chunked_req)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
                self.running_batch.batch_is_full = False

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch()
            # Requests left the batch, so the running batch may have room again.
            if self.last_batch.batch_size() < last_bs:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                # The update may leave the batch empty; run nothing then.
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            # NOTE(review): presumably coordinates the batch across DP ranks
            # (the second return value is unused here) — confirm in
            # prepare_dp_attn_batch.
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
1159

Lianmin Zheng's avatar
Lianmin Zheng committed
1160
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Assemble a new prefill (extend) batch from the waiting queue.

        Steps:
          1. Move requests whose grammar became ready out of
             ``grammar_queue``.
          2. Return ``None`` early when nothing can be prefilled:
             * the running batch is full or nothing is waiting, and no
               chunked request is pending; or
             * ``max_running_requests`` is already reached.
          3. Rank ``waiting_queue`` with ``self.policy``.
          4. Admit requests through a ``PrefillAdder``, starting with any
             pending chunked request, until a LoRA-count or running-request
             limit is hit or the adder stops accepting.
          5. Remove admitted requests from ``waiting_queue`` and wrap them in
             a ``ScheduleBatch`` prepared for extend.
          6. With mixed chunked prefill (and no logprob requests on either
             side), fold the running decode requests into the same batch.

        Returns:
            The new batch, or ``None`` if no request could be admitted.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # A request whose prefill was split into chunks continues first.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            dequeue_time = time.time()
            for req in can_run_list:
                req.queue_time_end = dequeue_time

        # Build the membership set once. Rebuilding it per element (as before)
        # made this filter O(len(waiting_queue) * len(can_run_list)).
        admitted = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in admitted]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The decode requests now live in new_batch; keep only the
            # "full" flag on a fresh, empty running batch.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1303
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running decode batch for its next step.

        Steps:
          1. Filter ``batch`` in place. If it becomes empty, clear
             ``batch_is_full`` and return it unchanged otherwise.
          2. Retract when ``check_decode_mem`` fails, or when
             ``TEST_RETRACT`` is set and the batch holds more than 10
             requests. Retraction sends requests back to the queue and adopts
             the token ratio returned by ``retract_decode``.
          3. Otherwise decay ``new_token_ratio`` toward
             ``min_new_token_ratio``.
          4. Clear ``batch_is_full`` if the batch shrank.
          5. Call ``prepare_for_decode()`` and return the same ``batch``
             object.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # check_decode_mem is evaluated first and exactly once; the test hook
        # is only consulted when memory is sufficient.
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            previous_ratio = self.new_token_ratio
            retracted, updated_ratio = batch.retract_decode(self.server_args)
            self.new_token_ratio = updated_ratio
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted)

        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1339

1340
1341
1342
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Execute one forward pass for ``batch`` and wrap its outputs.

        Increments ``self.forward_ct``, stops a step-limited profiler once its
        target step is reached, then dispatches to the embedding path, the
        regular generation path, or the speculative-decoding path.
        """
        self.forward_ct += 1

        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if not self.is_generation:
            # Embedding or reward model.
            worker_batch = batch.get_model_worker_batch()
            return EmbeddingBatchResult(
                embeddings=self.tp_worker.forward_batch_embedding(worker_batch),
                bid=worker_batch.bid,
            )

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted
        batch.output_ids = next_token_ids

        # Snapshot these per-request values now: the overlap scheduler may
        # modify them before the output is processed.
        if batch.return_logprob:
            input_lens = [req.extend_input_len for req in batch.reqs]
            logprob_start_lens = [req.extend_logprob_start_len for req in batch.reqs]
        else:
            input_lens = logprob_start_lens = None

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=input_lens,
            extend_logprob_start_len_per_req=logprob_start_lens,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1401

1402
1403
1404
1405
1406
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Post-process the output of a forward pass according to its mode.

        Decode and extend batches go to their dedicated handlers. For idle
        batches (overlap scheduling only) and dummy-first batches, the next
        batch's sampling info is finalized and its ``sampling_info_done`` is
        set. Afterwards, a pending health-check signal, if any, is sent to the
        tokenizer.
        """

        def publish_next_sampling_info():
            # Build the vocab mask, wait for the current stream, then signal.
            info = batch.next_batch_sampling_info
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            info.sampling_info_done.set()

        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                if batch.next_batch_sampling_info:
                    publish_next_sampling_info()
        elif mode.is_dummy_first():
            publish_next_sampling_info()

        if self.return_health_check_ct:
            # Emitting the health-check signal here keeps it from being blocked
            # behind long-context prefill. Caveat: this path does not check the
            # status of the detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1430
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across data-parallel attention ranks.

        Thin wrapper that feeds this scheduler's configuration into
        ``prepare_dp_attn_batch_raw`` and returns its result unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Exchange per-rank batch metadata across DP attention workers.

        Every rank contributes ``[num_tokens, can_cuda_graph,
        num_tokens_for_logprob, is_extend]`` to an all-gather over
        ``tp_cpu_group`` into a ``(dp_size, attn_tp_size, 4)`` int64 tensor.
        Only the entry at attention-TP index 0 of each DP row is read back.
        A rank without a batch obtains one from ``get_idle_batch`` when any
        gathered token count is positive.

        Returns:
            ``(local_batch, any_extend)``; ``local_batch`` can still be None
            when no rank has tokens.
        """
        if local_batch is None:
            num_tokens = num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # With EAGLE, each request counts speculative_num_draft_tokens.
                num_tokens *= speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            # At least one token per request is needed for sampling.
            num_tokens_for_logprob = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        can_cuda_graph = (
            1
            if local_batch is None or local_batch.forward_mode.is_decode_or_idle()
            else 0
        )
        if not spec_algorithm.is_none() and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            # TODO(sang): Support cuda graph when idle batch is there.
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        # Read the attention-TP rank 0 entry of each DP group.
        per_dp = global_info[:, 0, :]
        global_num_tokens = per_dp[:, 0].tolist()
        can_cuda_graph = min(per_dp[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp[:, 2].tolist()
        extend_flags = per_dp[:, 3].tolist()

        if local_batch is None and max(global_num_tokens) > 0:
            # Some rank has tokens, so this rank participates with an idle batch.
            local_batch = get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(extend_flags)
1522
1523
1524
1525
1526

    def get_idle_batch(self):
        """Build a request-less ScheduleBatch prepared for an idle forward pass."""
        batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Each queued grammar future is polled in order with a 0.05 s timeout.
        Polling stops at the first one that is not ready yet. With more than
        one (attention-)TP rank, the ready count is max-reduced across the
        group, and ranks that are behind block on the remaining futures. Every
        rank therefore moves the same prefix of the queue.
        """
        num_ready_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.05)
                num_ready_reqs += 1
            except futures.TimeoutError:
                # Public alias; same class as the private futures._base one.
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor(num_ready_reqs, dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max = tensor.item()
            for i in range(num_ready_reqs, num_ready_reqs_max):
                # Block until the grammars other ranks already have are ready here.
                self.grammar_queue[i].grammar = self.grammar_queue[i].grammar.result()
            num_ready_reqs = num_ready_reqs_max

        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        While a batch is in flight, the thread checks about every half
        ``watchdog_timeout`` whether ``forward_ct`` has advanced. If it has not
        advanced for longer than ``watchdog_timeout`` seconds, the thread:
          - logs diagnostics,
          - calls ``pyspy_dump_schedulers()``,
          - flushes stdout/stderr,
          - sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()
        # Use true division: ``watchdog_timeout // 2`` is 0 for timeouts below
        # 2 s, which turns this loop into a busy spin. The small floor guards
        # against the same problem for tiny or zero timeouts.
        poll_interval = max(self.watchdog_timeout / 2, 0.01)

        while True:
            current = time.time()
            # Read cur_batch exactly once per iteration. Other code, e.g.
            # flush_cache, can reset it to None concurrently, and the
            # diagnostics below dereference it.
            cur_batch = self.cur_batch
            if cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            time.sleep(poll_interval)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"{cur_batch.batch_size()=}, "
            f"{cur_batch.reqs=}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        pyspy_dump_schedulers()
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1599
1600
1601
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Request-handler adapter around ``flush_cache``; ``recv_req`` carries no data used here."""
        self.flush_cache()
        return None

1602
    def flush_cache(self):
        """Flush the memory pool and cache.

        This only proceeds when no request is waiting and the running batch is
        empty. On success it does the following:
          - clears ``cur_batch`` / ``last_batch``;
          - resets the tree cache and the grammar backend (if any);
          - clears the request-to-token and KV pools, including the draft
            worker's pools when speculative decoding is enabled;
          - zeroes the generation and speculative counters;
          - empties the CUDA cache.

        Returns:
            True if the cache was flushed, False if requests were pending.
        """
        if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
            self.cur_batch = None
            self.last_batch = None
            self.tree_cache.reset()
            if self.grammar_backend:
                self.grammar_backend.reset()
            self.req_to_token_pool.clear()
            self.token_to_kv_pool_allocator.clear()

            if not self.spec_algorithm.is_none():
                self.draft_worker.model_runner.req_to_token_pool.clear()
                self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

            self.num_generated_tokens = 0
            self.forward_ct_decode = 0
            self.spec_num_total_accepted_tokens = 0
            self.spec_num_total_forward_ct = 0
            self.cum_spec_accept_length = 0
            self.cum_spec_accept_count = 0
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
            if_success = True
        else:
            # Use the module logger (not the root ``logging`` module) so the
            # configured scheduler prefix and handlers apply.
            logger.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            if_success = False
        return if_success

1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a copy of the global server args augmented with runtime stats.

        Adds ``last_gen_throughput``. It also adds ``avg_spec_accept_length``
        when speculative decoding is on and has samples, and
        ``step_time_dict`` when RECORD_STEP_TIME is set.
        """
        state = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }

        spec_enabled = not self.spec_algorithm.is_none()
        if spec_enabled and self.cum_spec_accept_count > 0:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        The update is all-or-nothing. If any key in ``recv_req.server_args`` is
        not allow-listed, nothing changes. On success, the current average
        speculative accept length (if available) is logged and the accept
        statistics are reset before the new values are written.

        Returns:
            SetInternalStateReqOutput. ``updated`` reports whether the update
            was actually applied; ``server_args`` holds the current global args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k in server_args_dict:
            if k not in args_allow_update:
                # Module logger, so the scheduler's logging config applies.
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            global_server_args_dict.update(server_args_dict)
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Previously hard-coded True, which misreported rejected updates.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        Exceptions from the call are logged and returned as the error message
        instead of being raised. ``barrier()`` from ``torch.distributed`` runs
        after the call on both the success and failure paths.

        NOTE(review): ``recv_req.method`` selects any attribute of the
        scheduler. This is only safe if RPC requests come from a trusted
        component; confirm with callers.
        """
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        success = True
        error = None  # renamed from ``exec``, which shadowed the builtin
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Ask the model runner to save the model to ``params["url"]``."""
        url = params["url"]
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_remote_model(url)

    def save_sharded_model(self, params):
        """Ask the model runner to save a sharded checkpoint.

        ``params`` must provide ``path``, ``pattern`` and ``max_size``; they are
        forwarded as keyword arguments of the same names.
        """
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_sharded_model(
            path=params["path"], pattern=params["pattern"], max_size=params["max_size"]
        )

1713
1714
    def abort_request(self, recv_req: AbortReq):
        """Abort the first request whose rid starts with ``recv_req.rid``.

        If a matching request is still in the waiting queue, it is removed.
        Otherwise the first unfinished matching request in the running batch
        gets ``to_abort = True``. At most one request is affected per call.

        NOTE(review): because matching is by prefix, several requests may
        match, yet only the first is aborted; confirm this is intended.
        """
        # Waiting queue: remove the first match and stop.
        for idx, req in enumerate(self.waiting_queue):
            if not req.rid.startswith(recv_req.rid):
                continue
            req = self.waiting_queue.pop(idx)
            logger.debug(f"Abort queued request. {req.rid=}")
            return

        # Running batch: flag the first unfinished match.
        for req in self.running_batch.reqs:
            if not req.rid.startswith(recv_req.rid) or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
            return
1733

1734
1735
1736
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause hook; this scheduler provides no implementation and always raises."""
        raise NotImplementedError()

Chayenne's avatar
Chayenne committed
1737
1738
1739
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed, and the flush is asserted to succeed.
        On failure the worker's message is logged. The worker's
        ``(success, message)`` pair is returned with a trailing 0.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
1746

1747
1748
1749
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group via the TP worker."""
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1751
1752

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates to the TP worker's distributed update. On success the cache
        is flushed, and the flush is asserted to succeed. On failure the
        worker's message is logged.

        Returns:
            UpdateWeightsFromDistributedReqOutput(success, message). The
            previous annotation ``Tuple[bool, str]`` did not match this.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1764

1765
1766
1767
1768
1769
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        The cache is flushed only when the update succeeds and
        ``recv_req.flush_cache`` is set; that flush is asserted to succeed.
        On failure the worker's message is logged.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
1776

1777
1778
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a named parameter from the TP worker and wrap it in the reply type."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
1780

1781
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model buffers, pause the memory saver and flush the cache.

        The stashed buffers are restored later by ``resume_memory_occupation``.
        """
        self.memory_saver_adapter.check_validity(
            caller_name="release_memory_occupation"
        )
        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)
        self.memory_saver_adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
1791

1792
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver and write the stashed buffers back into the model.

        The stash created by ``release_memory_occupation`` is consumed (deleted).
        """
        self.memory_saver_adapter.check_validity(caller_name="resume_memory_occupation")
        self.memory_saver_adapter.resume()
        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state
        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Start or stop profiling depending on ``recv_req.type``."""
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        return self.start_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_id,
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional[ProfileReqOutput]:
        """Start profiling the requested activities.

        Args:
            output_dir: Trace directory. Defaults to $SGLANG_TORCH_PROFILER_DIR,
                or "/tmp" if that is unset.
            num_steps: If set, ``run_batch`` stops profiling once ``forward_ct``
                reaches ``forward_ct + num_steps``, and ``stop_profile`` then
                sends the reply to the tokenizer.
            activities: Any of "CPU", "GPU", "MEM", "CUDA_PROFILER". Defaults
                to ["CPU", "GPU"].
            with_stack: Forwarded to ``torch.profiler.profile``; defaults to True.
            record_shapes: Forwarded to ``torch.profiler.profile``; defaults to False.
            profile_id: Prefix for output file names. Defaults to the current
                timestamp.

        Returns:
            A failure ProfileReqOutput if profiling is already active. A
            success ProfileReqOutput when ``num_steps`` is not set. None
            otherwise, because the reply is deferred to ``stop_profile``.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # stop_profile() builds file names as ``self.profiler_id + "-TP-..."``,
            # which raises TypeError on None, so fall back to a timestamp id.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        # Only CPU/GPU map to torch.profiler activities; "MEM" and
        # "CUDA_PROFILER" are handled separately below.
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")
1871
1872

    def stop_profile(self) -> None:
        """Stop every active profiler and write its output to the trace directory.

        Does nothing when no profiling session is active. The following are
        written under ``torch_profiler_output_dir``:
          - the torch profiler's chrome trace (if one was started);
          - a CUDA memory snapshot (for "MEM");
          - nothing extra for "CUDA_PROFILER", which is just stopped.
        The profiler state is then cleared. When the session was started with
        ``num_steps``, a success ProfileReqOutput is sent to the tokenizer.
        """
        activities = self.profiler_activities
        if activities is None:
            return

        logger.info("Stop profiling...")
        out_dir = self.torch_profiler_output_dir
        tp_suffix = f"-TP-{self.tp_rank}"

        profiler = self.torch_profiler
        if profiler is not None:
            profiler.stop()
            trace_name = self.profiler_id + tp_suffix + ".trace.json.gz"
            profiler.export_chrome_trace(os.path.join(out_dir, trace_name))

        if "MEM" in activities:
            snapshot_name = self.profiler_id + tp_suffix + "-memory" + ".pickle"
            torch.cuda.memory._dump_snapshot(os.path.join(out_dir, snapshot_name))
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info("Profiling done. Traces are saved to: %s", out_dir)

        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        if self.profiler_target_forward_ct:
            # Step-limited sessions reply here instead of in start_profile().
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
1909

1910
1911
1912
1913
1914
1915
1916
1917
1918
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump expert-distribution recording as requested.

        Raises:
            ValueError: if ``recv_req`` is not one of the three known commands.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            action = expert_distribution_recorder.start_record
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            action = expert_distribution_recorder.stop_record
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            action = expert_distribution_recorder.dump_record
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        action()
        return ExpertDistributionReqOutput()
1920

1921
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new Session under ``recv_req.session_id``.

        Returns an OpenSessionReqOutput with success=False, and logs a
        warning, when the id already exists or is None.
        """
        session_id = recv_req.session_id

        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
1935
1936
1937
1938
1939
1940
1941
1942
1943

    def close_session(self, recv_req: CloseSessionReqInput):
        """Forget the session named by ``recv_req.session_id``; warn if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
            return
        logger.warning(f"session id {session_id} does not exist, cannot delete.")

1944

1945
1946
1947
1948
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has a string ``rid`` starting with "HEALTH_CHECK".

    Objects without ``rid``, or whose ``rid`` is not a string (e.g. None or
    a list of ids for a batched request), are not health checks. Previously
    the latter raised AttributeError on ``.startswith``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
def _export_static_state(model):
    """Snapshot every named buffer of ``model`` as a detached clone.

    Returns ``{"buffers": [(name, tensor), ...]}`` in
    ``model.named_buffers()`` order. The clones do not share storage with
    the model.
    """
    snapshot = []
    for buf_name, buf in model.named_buffers():
        snapshot.append((buf_name, buf.detach().clone()))
    return {"buffers": snapshot}


def _import_static_state(model, static_params):
    """Copy stashed buffer values back into ``model``'s existing buffers in place.

    ``static_params["buffers"]`` is an iterable of ``(name, tensor)`` pairs,
    the layout built by ``_export_static_state``. A name that is not among
    ``model.named_buffers()`` raises KeyError.
    """
    live_buffers = dict(model.named_buffers())
    for buf_name, saved in static_params["buffers"]:
        target = live_buffers[buf_name]
        # Slice-assign so the model keeps its original tensor objects.
        target[...] = saved


1963
1964
1965
1966
1967
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Performs process-level setup (parent-death handling, process title,
    faulthandler, logging, optional CPU affinity), constructs a ``Scheduler``,
    reports readiness through ``pipe_writer``, then runs the event loop that
    matches the scheduler's disaggregation mode. Exceptions raised while
    constructing or running the scheduler are logged, and SIGQUIT is sent to
    the parent process.

    Args:
        server_args: Server configuration, forwarded to ``configure_logger``,
            ``set_gpu_proc_affinity`` and ``Scheduler``.
        port_args: Port configuration, forwarded to ``Scheduler``.
        gpu_id: GPU id, forwarded to ``Scheduler`` and ``set_gpu_proc_affinity``.
        tp_rank: Tensor-parallel rank of this scheduler.
        dp_rank: Data-parallel rank or ``None``. If ``None`` and the
            ``SGLANG_DP_RANK`` environment variable is set, its integer value
            is passed to ``Scheduler`` instead.
        pipe_writer: Only ``send`` is used; it receives one dict with keys
            ``status``, ``max_total_num_tokens`` and ``max_req_input_len``.
            Presumably the writing end of a multiprocessing pipe (verify
            against the caller).
    """
    # Tag used for the log prefix and process title, e.g. " TP0" or " DP1 TP0".
    prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"

    # Process-level setup.
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] SGLANG_DP_RANK supplies dp_rank when none was passed in.
    # NOTE(review): this runs after `prefix` was built, so a DP rank taken from
    # the env var does not appear in the log prefix or process title;
    # confirm that is intended.
    env_dp_rank = os.environ.get("SGLANG_DP_RANK")
    if dp_rank is None and env_dp_rank is not None:
        dp_rank = int(env_dp_rank)

    # Logging setup.
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Opt-in CPU affinity for this GPU process.
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)

        # Report readiness and capacity to the other end of the pipe.
        ready_msg = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
            "max_req_input_len": scheduler.max_req_input_len,
        }
        pipe_writer.send(ready_msg)

        # Pick the event loop for the disaggregation mode, then run it.
        mode = scheduler.disaggregation_mode
        if mode == DisaggregationMode.NULL:
            run_event_loop = (
                scheduler.event_loop_overlap
                if scheduler.enable_overlap
                else scheduler.event_loop_normal
            )
        elif mode == DisaggregationMode.PREFILL:
            run_event_loop = scheduler.event_loop_normal_disagg_prefill
        elif mode == DisaggregationMode.DECODE:
            run_event_loop = scheduler.event_loop_normal_disagg_decode
        else:
            # Any other mode: no event loop is started and the function returns.
            run_event_loop = None

        if run_event_loop is not None:
            run_event_loop()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {tb_text}")
        parent_process.send_signal(signal.SIGQUIT)