scheduler.py 77.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
23
import time
import warnings
24
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
25
from concurrent import futures
26
from dataclasses import dataclass
27
from http import HTTPStatus
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
40
41
42
43
44
45
46
47
48
49
50
51
52
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
)
53
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
54
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
55
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
56
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
57
58
from sglang.srt.managers.io_struct import (
    AbortReq,
59
    CloseSessionReqInput,
60
    ExpertDistributionReq,
61
    ExpertDistributionReqOutput,
62
    FlushCacheReq,
63
64
    GetInternalStateReq,
    GetInternalStateReqOutput,
65
66
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
67
    HealthCheckOutput,
68
69
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
70
71
    OpenSessionReqInput,
    OpenSessionReqOutput,
72
    ProfileReq,
73
74
    ProfileReqOutput,
    ProfileReqType,
75
76
77
78
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
79
80
    RpcReqInput,
    RpcReqOutput,
81
82
    SetInternalStateReq,
    SetInternalStateReqOutput,
83
84
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
85
86
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
87
88
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
89
90
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
91
92
93
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
94
    MultimodalInputs,
95
96
    Req,
    ScheduleBatch,
97
    global_server_args_dict,
98
)
99
100
101
102
103
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
104
105
106
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
107
from sglang.srt.managers.session_controller import Session
108
from sglang.srt.managers.tp_worker import TpModelWorker
109
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
110
from sglang.srt.managers.utils import validate_input_length
111
from sglang.srt.mem_cache.chunk_cache import ChunkCache
112
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
113
from sglang.srt.mem_cache.radix_cache import RadixCache
114
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
115
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
116
from sglang.srt.server_args import PortArgs, ServerArgs
117
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
118
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
119
from sglang.srt.utils import (
120
    DynamicGradMode,
121
122
    broadcast_pyobj,
    configure_logger,
123
    crash_on_warnings,
124
    get_bool_env_var,
125
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
126
    kill_itself_when_parent_died,
127
    pyspy_dump_schedulers,
128
    set_gpu_proc_affinity,
129
130
131
    set_random_seed,
    suppress_other_loggers,
)
132
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
133

134
135
expert_distribution_recorder = ExpertDistributionRecorder()

136
137
logger = logging.getLogger(__name__)

138
# Test retract decode for debugging purposes
139
140
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
141

142

143
144
145
146
@dataclass
class GenerationBatchResult:
    """Plain record bundling the outputs of one generation batch."""

    # Output object of the logits processor.
    logits_output: LogitsProcessorOutput
    # Next token ids — presumably one entry per request; TODO confirm with the producer.
    next_token_ids: List[int]
    # Per-request extend input lengths.
    extend_input_len_per_req: List[int]
    # Per-request start offsets for extend logprobs.
    extend_logprob_start_len_per_req: List[int]
    # NOTE(review): presumably the id of the batch this result belongs to — confirm.
    bid: int


@dataclass
class EmbeddingBatchResult:
    """Plain record bundling the outputs of one embedding batch."""

    # Embeddings tensor — shape is not visible here; TODO confirm with the producer.
    embeddings: torch.Tensor
    # NOTE(review): presumably the id of the batch this result belongs to — confirm.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
158
159
160
161
162
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
163
164
165
166
167
168
169
170
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler and everything it owns.

        Initialization is order-dependent: argument parsing, rank info, ZMQ
        sockets, tokenizer/model config, the tensor parallel worker (and the
        optional EAGLE draft worker), memory pool and cache, running state,
        chunked prefill, grammar backend, schedule policy, watchdog thread,
        memory saver, profiler state, metrics, request dispatcher and finally
        the disaggregation queues.

        Args:
            server_args: Server configuration read throughout initialization.
            port_args: Provides the IPC names for the ZMQ sockets and the
                nccl port passed to the workers.
            gpu_id: Stored on the scheduler and forwarded to the workers.
            tp_rank: Tensor parallel rank; stored and forwarded to the workers.
            dp_rank: Forwarded unchanged to the workers. Note that
                ``self.dp_rank`` is not this value; it is taken from the
                result of ``compute_dp_attention_world_info``.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.tp_size = server_args.tp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        self.dp_size = server_args.dp_size
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        # Only attention-TP rank 0 owns real sockets; the other ranks get
        # None receivers and senders whose send_pyobj does nothing.
        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        # (also sets self.model_config and self.is_generation, used right below)
        self.init_tokenizer()

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")
        if self.model_config.is_multimodal:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            # Imported lazily, only when the EAGLE algorithm is selected.
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        # (the last three values of the worker info tuple are discarded)
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        # Merge the worker's view of the server args into the module-wide dict.
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        logger.info(
            f"max_total_num_tokens={self.max_total_num_tokens}, "
            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
            f"max_prefill_tokens={self.max_prefill_tokens}, "
            f"max_running_requests={self.max_running_requests}, "
            f"context_len={self.model_config.context_len}"
        )

        # Init memory pool and cache
        # (sets self.tree_cache, which the schedule policy below needs)
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        # NOTE(review): this comparison assumes server_args.chunked_prefill_size
        # is a number here; a None value would raise TypeError — confirm that
        # ServerArgs always fills it in.
        if self.chunked_prefill_size <= 0:  # -1 means disable
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # Both ratios are capped at 1.0.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Per-step decrement that takes the ratio from its initial value to
        # its minimum over default_new_token_ratio_decay_steps steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        # NOTE(review): the watchdog thread is started before
        # self.parent_process is assigned — verify watchdog_thread cannot
        # read it before this line runs.
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.torch_profiler_activities: Optional[List[str]] = None
        self.profiler_target_forward_ct: Optional[int] = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher
        # Maps each incoming request type to the method that handles it.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReq, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        # Init disaggregation (prefill / decode split) state
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

420
421
    def init_tokenizer(self):
        """Build the model config and load the tokenizer / processor.

        Sets ``self.model_config`` and ``self.is_generation`` unconditionally.
        ``self.tokenizer`` and ``self.processor`` are always defined afterwards:
        both are ``None`` when ``skip_tokenizer_init`` is set, and
        ``self.processor`` is ``None`` for non-multimodal models. Previously
        ``self.processor`` was left unassigned for non-multimodal models, so
        any later read of it raised ``AttributeError``.
        """
        server_args = self.server_args

        self.model_config = ModelConfig(
            server_args.model_path,
            trust_remote_code=server_args.trust_remote_code,
            revision=server_args.revision,
            context_length=server_args.context_length,
            model_override_args=server_args.json_model_override_args,
            is_embedding=server_args.is_embedding,
            dtype=server_args.dtype,
            quantization=server_args.quantization,
        )
        self.is_generation = self.model_config.is_generation

        # Define both attributes on every path so callers can rely on an
        # `is None` check instead of hasattr().
        self.tokenizer = None
        self.processor = None
        if server_args.skip_tokenizer_init:
            return

        if self.model_config.is_multimodal:
            # Multimodal models load a processor and use the tokenizer it carries.
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
            self.tokenizer = self.processor.tokenizer
        else:
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the worker and build the prefix cache.

        Sets ``self.req_to_token_pool``, ``self.token_to_kv_pool_allocator``,
        ``self.tree_cache`` and ``self.decode_mem_cache_buf_multiplier``.
        """
        args = self.server_args

        pools = self.tp_worker.get_memory_pool()
        self.req_to_token_pool, self.token_to_kv_pool_allocator = pools

        # Pick the cache implementation:
        #   chunked prefill size set + radix cache disabled -> ChunkCache
        #   hierarchical cache enabled                      -> HiRadixCache
        #   otherwise                                       -> RadixCache
        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_worker.get_tp_cpu_group(),
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
            )

        # Without speculative decoding the multiplier is 1; otherwise it is
        # num_draft_tokens + eagle_topk * num_steps.
        if self.spec_algorithm.is_none():
            multiplier = 1
        else:
            tree_tokens = args.speculative_eagle_topk * args.speculative_num_steps
            multiplier = args.speculative_num_draft_tokens + tree_tokens
        self.decode_mem_cache_buf_multiplier = multiplier
497
498
499
500
501
502
503

    def init_metrics(self):
        """Reset all statistics counters and, if enabled, create the metrics collector."""
        # Largest prefill length seen for a single request, and largest total
        # context length (prefill + generation) seen for a single request.
        self._largest_prefill_len = 0
        self._largest_prefill_decode_len = 0

        # Most recent throughput figures.
        self.last_gen_throughput = 0.0
        self.last_input_throughput = 0.0

        # Dict[batch size -> step time]
        self.step_time_dict = defaultdict(list)

        # Speculative decoding counters.
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        collector_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=collector_labels)
Lianmin Zheng's avatar
Lianmin Zheng committed
519

Byron Hsu's avatar
Byron Hsu committed
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
    def init_disaggregation(self):
        """Create the queues and metadata buffers for the disaggregation mode.

        DECODE mode builds the transfer queue and the pre-allocation queue.
        PREFILL mode builds the bootstrap queue and an empty list of requests
        whose KV cache is being sent. Any other mode does nothing.
        """
        mode = self.disaggregation_mode
        is_decode = mode == DisaggregationMode.DECODE
        if is_decode:
            slots = self.req_to_token_pool.size
        elif mode == DisaggregationMode.PREFILL:
            slots = self.max_running_requests
        else:
            return

        # *2 for the headroom.
        buffer_size = slots * 2
        idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
        aux_dtype = torch.int32
        # A list of metadata buffers, each of shape (buffer_size, metadata_size).
        # The last dimension times dtype.itemsize has to be large enough
        # (64 bytes) to work with RDMA, hence the padding to 16 int32 values.
        metadata_buffers = [
            torch.zeros((buffer_size, 16), dtype=aux_dtype, device="cpu")
        ]

        if not is_decode:
            # Prefill side: requests waiting for bootstrap.
            self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_infight_queue: List[Req] = []
            return

        # Decode side: requests polling for their kv cache.
        self.disagg_decode_transfer_queue = DecodeTransferQueue(
            gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            req_to_metadata_buffer_idx_allocator=idx_allocator,
            metadata_buffers=metadata_buffers,
        )

        # Decode side: requests waiting for pre-allocation.
        self.disagg_decode_prealloc_queue = DecodePreallocQueue(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            req_to_metadata_buffer_idx_allocator=idx_allocator,
            metadata_buffers=metadata_buffers,
            aux_dtype=aux_dtype,
            scheduler=self,
            transfer_queue=self.disagg_decode_transfer_queue,
            tree_cache=self.tree_cache,
            gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            bootstrap_port=self.server_args.disaggregation_bootstrap_port,
        )

587
    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the plain scheduler loop forever: receive, schedule, run, process."""
        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Nothing to run: the server is idle, so do the self-check
                # and reset the new-token ratio to its initial value.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                output = self.run_batch(batch)
                self.process_batch_result(batch, output)

            self.last_batch = batch
606

607
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Run the scheduler loop that overlaps CPU processing with GPU computation.

        Each iteration launches the current batch and queues its result, then
        processes the queued result of the previous iteration's batch.
        """
        self.result_queue = deque()

        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                output = self.run_batch(batch)
                self.result_queue.append((batch.copy(), output))

                if self.last_batch is None:
                    # Start the overlap pipeline with a dummy first batch.
                    # It is now used for triggering the sampling_info_done event.
                    bootstrap_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(bootstrap_batch, None)

            if self.last_batch:
                # Process the results of the last batch
                prev_batch, prev_output = self.result_queue.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                else:
                    prev_batch.next_batch_sampling_info = None
                self.process_batch_result(prev_batch, prev_output)
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

Byron Hsu's avatar
Byron Hsu committed
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
    @torch.no_grad()
    def event_loop_normal_disagg_prefill(self):
        """Run the scheduler loop of a prefill worker in disaggregation mode."""
        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            # Requests that finished bootstrapping join the waiting queue.
            bootstrapped = self.disagg_prefill_pending_queue.pop_bootstrapped()
            self.waiting_queue.extend(bootstrapped)

            self.process_prefill_chunk()
            batch = self.get_new_batch_prefill()
            self.cur_batch = batch

            if batch:
                output = self.run_batch(batch)
                self.process_batch_result_disagg_prefill(batch, output)

            if self.disagg_prefill_infight_queue:
                self.process_disagg_prefill_infight_queue()

            # Idle only when there is no batch and nothing is still in flight
            # (checked again after the in-flight queue was processed above).
            if batch is None and not self.disagg_prefill_infight_queue:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
            # HACK (byronhsu): clear batch_is_full here, since this loop never
            # enters update_running_batch, which is what normally resets it.
            # Without this the loop hangs under high concurrency.
            self.running_batch.batch_is_full = False

    @torch.no_grad()
    def event_loop_normal_disagg_decode(self):
        """Run the scheduler loop of a decode worker in disaggregation mode."""
        while True:
            incoming = self.recv_requests()
            self.process_input_requests(incoming)

            # Poll transfers and allocate kv cache.
            self.process_decode_queue()
            batch = self.get_next_disagg_decode_batch_to_run()
            self.cur_batch = batch

            if batch:
                if not batch.forward_mode.is_extend():
                    output = self.run_batch(batch)
                    self.process_batch_result(batch, output)
                else:
                    # Extend batches are not run here; only a fake extend
                    # output is streamed, with a False flag per request.
                    # Note: Logprobs should be handled on the prefill engine.
                    self.stream_output(batch.reqs, [False] * len(batch.reqs))

            if batch is None:
                transfer_pending = len(self.disagg_decode_transfer_queue.queue)
                prealloc_pending = len(self.disagg_decode_prealloc_queue.queue)
                if transfer_pending + prealloc_pending == 0:
                    # When the server is idle, do self-check and re-init some states
                    self.check_memory()
                    self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

711
712
    def recv_requests(self) -> List[Req]:
        """Collect pending requests on attention-TP rank 0 and broadcast them.

        Rank 0 of the attention-TP group drains the tokenizer socket and then
        the RPC socket without blocking; other ranks start with ``None``.
        With DP attention enabled, work requests (generate/embedding inputs)
        are broadcast within the attention-TP group and all other (control)
        requests within the whole TP group. Otherwise the full list is
        broadcast within the TP group when ``tp_size != 1``.

        Returns:
            The received requests, work requests first when DP attention is
            enabled.
        """
        if self.attn_tp_rank != 0:
            recv_reqs = None
        else:
            recv_reqs = []
            # Drain each socket in turn; a ZMQError on a non-blocking receive
            # ends the drain for that socket.
            for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                while True:
                    try:
                        received = sock.recv_pyobj(zmq.NOBLOCK)
                    except zmq.ZMQError:
                        break
                    recv_reqs.append(received)

        if not self.server_args.enable_dp_attention:
            if self.tp_size != 1:
                recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
            return recv_reqs

        if self.attn_tp_rank == 0:
            # Split the received objects into work and control requests,
            # preserving arrival order inside each group.
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs, control_reqs = [], []
            for req in recv_reqs:
                bucket = work_reqs if isinstance(req, work_types) else control_reqs
                bucket.append(req)
        else:
            work_reqs = None
            control_reqs = None

        if self.attn_tp_size != 1:
            # The broadcast source is the first rank of this DP rank's attention-TP group.
            src_rank = self.dp_rank * self.attn_tp_size
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_rank,
                self.attn_tp_cpu_group,
                src=src_rank,
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs, self.tp_rank, self.tp_cpu_group
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
769
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request and send back any produced output.

        Health-check generation requests are only counted (not dispatched)
        while a chunked request or a non-empty running batch exists.
        ``RpcReqOutput`` results are sent on the RPC socket when it exists;
        every other non-None result is sent to the tokenizer.
        """
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            if is_health_check_generate_req(recv_req):
                if self.chunked_req is not None or not self.running_batch.is_empty():
                    self.return_health_check_ct += 1
                    continue

            output = self._request_dispatcher(recv_req)
            if output is None:
                continue

            if not isinstance(output, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(output)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(output)
785
786
787
788
789

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        The request is built either from scratch or from an existing session.
        It is then validated (multimodal expansion, prompt length,
        ``logprob_start_len``) and finally placed in the grammar queue when it
        needs a grammar that is not cached yet, or in the regular request
        queue otherwise. Requests that fail validation are still enqueued,
        after being marked/trimmed, so that the failure is reported through
        the normal output path.

        Args:
            recv_req: The tokenized generate request to schedule.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            # Handle custom logit processor passed to the request
            custom_logit_processor = recv_req.custom_logit_processor
            if (
                not self.server_args.enable_custom_logit_processor
                and custom_logit_processor is not None
            ):
                logger.warning(
                    "The SGLang server is not configured to enable custom logit processor. "
                    "The custom logit processor passed in will be ignored. "
                    "Please set --enable-custom-logit-processor to enable this feature."
                )
                custom_logit_processor = None

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
            )
            req.tokenizer = self.tokenizer

            # A session id was supplied but is unknown: abort the request.
            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Shrink the request to a single token and no generation
                # before enqueueing the aborted request.
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap the generation length by the remaining room in max_req_len;
        # an unset max_new_tokens is treated as 1 << 30 before capping.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        sampling_params = req.sampling_params
        grammar_specs = (
            ("json", sampling_params.json_schema),
            ("regex", sampling_params.regex),
            ("ebnf", sampling_params.ebnf),
            ("structural_tag", sampling_params.structural_tag),
        )
        # The first constraint that is set (in the order above) wins. All four
        # are tested with `is not None`, so a falsy non-None structural_tag
        # can no longer leave `key` unbound.
        key = next((spec for spec in grammar_specs if spec[1] is not None), None)

        add_to_grammar_queue = False
        if key is not None:
            assert self.grammar_backend is not None
            req.grammar = self.grammar_backend.get_cached_value(key)
            if not req.grammar:
                # Not cached: take the backend's future value and park the
                # request in the grammar queue instead of the request queue.
                req.grammar = self.grammar_backend.get_future_value(key)
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Enqueue one request into the queue matching the disaggregation mode.

        PREFILL mode uses the prefill pending queue, DECODE mode the decode
        pre-allocation queue, and any other mode the regular waiting queue.
        """
        mode = self.disaggregation_mode

        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_pending_queue.add(req)
            return

        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Enqueue several requests at once.

        In DECODE disaggregation mode the requests go to the decode
        pre-allocation queue; in every other mode they go to the waiting
        queue. ``is_retracted`` is accepted but not used by this method.
        """
        in_decode_mode = self.disaggregation_mode == DisaggregationMode.DECODE
        target_queue = (
            self.disagg_decode_prealloc_queue if in_decode_mode else self.waiting_queue
        )
        target_queue.extend(reqs)
954
955
956

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        The request is validated (multimodal expansion and prompt length).
        Requests that fail validation are still enqueued so that the failure
        is reported through the normal output path. Every path enqueues via
        ``_add_request_to_queue`` so the disaggregation mode is honoured
        consistently.

        Args:
            recv_req: The tokenized embedding request to schedule.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                # Shrink the request to a single token and no generation
                # before enqueueing the aborted request.
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper (as the other exits of this method
                # do) instead of appending to waiting_queue directly.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1004

1005
1006
1007
1008
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log one line of prefill statistics and optionally export metrics.

        Also updates the input-throughput bookkeeping (``last_input_throughput``
        is the number of prefill tokens counted since the previous call divided
        by the elapsed time) and the largest prefill length seen so far.

        Args:
            adder: The prefill adder; its ``log_input_tokens`` and
                ``log_hit_tokens`` counters are reported.
            can_run_list: Requests selected for the new prefill batch.
            running_bs: Number of currently running requests to report.
        """
        elapsed = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / elapsed
        self.num_prefill_tokens = 0

        # Tokens in use = capacity minus what is free or evictable.
        reclaimable = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - reclaimable
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        token_usage = num_used / self.max_total_num_tokens
        fields = [
            ("#new-seq", len(can_run_list)),
            ("#new-token", adder.log_input_tokens),
            ("#cached-token", adder.log_hit_tokens),
            ("token usage", f"{token_usage:.2f}"),
            ("#running-req", running_bs),
            ("#queue-req", len(self.waiting_queue)),
        ]
        msg = "Prefill batch. " + "".join(
            f"{label}: {value}, " for label, value in fields
        )
        logger.info(msg)

        if not self.enable_metrics:
            return

        total_tokens = adder.log_input_tokens + adder.log_hit_tokens
        cache_hit_rate = adder.log_hit_tokens / total_tokens

        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(token_usage, 2)
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.cache_hit_rate = cache_hit_rate
        self.metrics_collector.log_stats(stats)

    def log_decode_stats(self):
        """Log one line of decode statistics and optionally export metrics.

        Updates the generation throughput (tokens generated since the previous
        call divided by the elapsed time). When a speculative algorithm is
        active, it also reports the accept length, accumulates the speculative
        counters, and resets them.
        """
        elapsed = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0

        num_running_reqs = len(self.running_batch.reqs)
        # Tokens in use = capacity minus what is free or evictable.
        reclaimable = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - reclaimable

        if RECORD_STEP_TIME:
            per_step = elapsed / self.server_args.decode_log_interval
            self.step_time_dict[num_running_reqs].append(per_step)

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
            accept_part = ""
        else:
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
            )
            # Fold the interval counters into the cumulative ones, then reset.
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            accept_part = f"accept len: {spec_accept_length:.2f}, "

        # The accept-length field only appears when speculation is active.
        msg = (
            "Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"{accept_part}"
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}, "
        )
        logger.info(msg)

        if not self.enable_metrics:
            return

        stats = self.stats
        stats.num_running_reqs = num_running_reqs
        stats.num_used_tokens = num_used
        stats.token_usage = num_used / self.max_total_num_tokens
        stats.cache_hit_rate = 0.0
        stats.gen_throughput = self.last_gen_throughput
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.spec_accept_length = spec_accept_length
        self.metrics_collector.log_stats(stats)

Lianmin Zheng's avatar
Lianmin Zheng committed
1100
1101
    def check_memory(self):
        """Check the memory pools for leaks and emit idle-time metrics.

        Two invariants are checked:

        * available + evictable KV tokens must equal the pool capacity (minus
          the tree cache's protected size when hierarchical cache is enabled);
        * every slot of the request-to-token pool must be free.

        A violation issues a warning and, when ``crash_on_warnings()`` is
        true, raises instead.

        Raises:
            ValueError: If a leak is detected and ``crash_on_warnings()``
                returns true.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "KV cache pool leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            # Trailing space added after the first sentence so the message no
            # longer renders as "detected!available_size=...".
            msg = (
                "Memory pool leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            warnings.warn(msg)
            if crash_on_warnings():
                raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.metrics_collector.log_stats(self.stats)

1150
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Pick the batch for the next forward step.

        If the previous batch was an extend (prefill) batch, its remaining
        requests are first merged into the running batch. A new prefill batch
        is preferred when one can be formed; otherwise the running batch is
        updated and returned for decoding.

        Returns:
            The batch to run, or ``None`` when there is nothing to run (the
            DP-attention preparation step may replace this value).
        """
        # Merge the prefill batch into the running batch
        prev_batch = self.last_batch
        if prev_batch and prev_batch.forward_mode.is_extend():
            chunked = self.chunked_req
            if chunked:
                # Move the chunked request out of the batch so that we can merge
                # only finished requests to running_batch.
                prev_batch.filter_batch(chunked_req_to_exclude=chunked)
                self.tree_cache.cache_unfinished_req(chunked)
                # chunked request keeps its rid but will get a new req_pool_idx
                self.req_to_token_pool.free(chunked.req_pool_idx)
                self.running_batch.batch_is_full = False

            # Filter batch; a shrinking batch means capacity was released.
            size_before = prev_batch.batch_size()
            prev_batch.filter_batch()
            if prev_batch.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not prev_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(prev_batch)

        # Run prefill first if possible
        ret = self.get_new_batch_prefill()
        if ret is None and not self.running_batch.is_empty():
            # Run decode
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                ret = self.running_batch

        # Handle DP attention
        needs_dp_prepare = (
            self.server_args.enable_dp_attention
            or self.server_args.enable_sp_layernorm
        )
        if needs_dp_prepare:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret
1193

Lianmin Zheng's avatar
Lianmin Zheng committed
1194
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Try to build a new prefill (extend) batch from the waiting queue.

        Requests are taken from the waiting queue, in the order produced by
        the scheduling policy, until the adder reports that no more can be
        added or a LoRA / running-request limit is hit. Selected requests are
        removed from the waiting queue. With mixed chunked prefill enabled,
        the current running batch may be mixed into the new batch.

        Returns:
            The new batch, or ``None`` when prefill is not possible or no
            request could be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        if running_bs >= self.max_running_requests:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        # A partially prefilled request from an earlier round goes first.
        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop when adding this request would exceed the number of
            # distinct LoRA adapters allowed in one batch.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if running_bs + len(adder.can_run_list) >= self.max_running_requests:
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )
            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (
                            self.running_batch is not None
                            and not self.running_batch.is_empty()
                        )
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None
        # Build the lookup set once instead of once per waiting request.
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.read_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            # The running batch is replaced by an empty one that keeps the
            # current batch_is_full flag.
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1331
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Prepare the running batch for its next decode step.

        The batch is filtered first. If ``check_decode_mem`` fails (or the
        ``TEST_RETRACT`` switch forces it for batches larger than 10), some
        requests are retracted and re-queued and the new-token ratio is
        replaced by the one returned from the retraction. Otherwise the ratio
        decays towards its minimum.

        Args:
            batch: The running decode batch; it is modified in place.

        Returns:
            The same batch object, possibly empty.
        """
        size_on_entry = batch.batch_size()
        batch.filter_batch()

        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        must_retract = not batch.check_decode_mem(
            self.decode_mem_cache_buf_multiplier
        ) or (TEST_RETRACT and batch.batch_size() > 10)

        if not must_retract:
            decayed_ratio = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed_ratio, self.min_new_token_ratio)
        else:
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            # Retracted requests go back to a queue to be scheduled again.
            self._extend_requests_to_queue(retracted_reqs)

        # Any shrinkage since entry means the batch has room again.
        if batch.batch_size() < size_on_entry:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1367

1368
1369
1370
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run one forward pass for ``batch`` and package its outputs.

        Increments ``forward_ct`` and stops the profiler once the target
        forward count has been reached. Generation models go through either
        the TP worker or, when a speculative algorithm is configured, the
        draft worker; other models (embedding or reward) go through
        ``forward_batch_embedding``.

        Args:
            batch: The batch to execute. For generation, its ``output_ids``
                is set to the produced next-token ids.

        Returns:
            A ``GenerationBatchResult`` for generation models, otherwise an
            ``EmbeddingBatchResult``.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        # Run forward
        if not self.is_generation:
            # embedding or reward model
            worker_batch = batch.get_model_worker_batch()
            return EmbeddingBatchResult(
                embeddings=self.tp_worker.forward_batch_embedding(worker_batch),
                bid=worker_batch.bid,
            )

        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                worker_batch
            )
            bid = worker_batch.bid
        else:
            spec_outputs = self.draft_worker.forward_batch_speculative_generation(
                batch
            )
            logits_output, next_token_ids, bid, num_accepted_tokens = spec_outputs
            # Update the speculative-decoding counters for this step.
            self.spec_num_total_accepted_tokens += (
                num_accepted_tokens + batch.batch_size()
            )
            self.spec_num_total_forward_ct += batch.batch_size()
            self.num_generated_tokens += num_accepted_tokens
        batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        extend_input_len_per_req = None
        extend_logprob_start_len_per_req = None
        if batch.return_logprob:
            extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
            extend_logprob_start_len_per_req = [
                req.extend_logprob_start_len for req in batch.reqs
            ]

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
        )
Chayenne's avatar
Chayenne committed
1429

1430
1431
1432
1433
1434
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
    ):
        """Route a finished batch to the handler matching its forward mode.

        Decode and extend batches go to their dedicated handlers. Idle batches
        (only when overlap is enabled) and dummy-first batches just finalize
        the next batch's sampling info. Afterwards, one pending health-check
        reply is sent if any is outstanding.

        Args:
            batch: The batch that was just run.
            result: The output produced for ``batch`` by ``run_batch``.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_batch_result(result.bid)
                sampling_info = batch.next_batch_sampling_info
                if sampling_info:
                    sampling_info.update_regex_vocab_mask()
                    self.current_stream.synchronize()
                    sampling_info.sampling_info_done.set()
        elif mode.is_dummy_first():
            sampling_info = batch.next_batch_sampling_info
            sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            sampling_info.sampling_info_done.set()

        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1458
1459
1460
1461
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Exchange per-rank batch statistics and attach them to the batch.

        Each rank contributes four int64 values (token count, whether it can
        use a CUDA graph, token count for logprob computation, whether it runs
        an extend batch). They are all-gathered over ``tp_cpu_group`` into a
        ``(dp_size, attn_tp_size, 4)`` tensor, and the entry at index 0 of the
        second axis is read for every DP rank.

        Args:
            local_batch: This rank's batch, or ``None`` when it has nothing to
                run.

        Returns:
            A tuple ``(local_batch, any_extend)``. ``local_batch`` is replaced
            by an idle batch when this rank had none but some other rank has
            tokens to process; ``any_extend`` is True if any gathered entry
            reports an extend batch.
        """
        spec_enabled = not self.spec_algorithm.is_none()

        # Check if other DP workers have running batches
        if local_batch is None:
            local_num_tokens = 0
            local_logprob_tokens = 0
        elif local_batch.forward_mode.is_decode():
            local_num_tokens = local_batch.batch_size()
            if spec_enabled and self.spec_algorithm.is_eagle():
                local_num_tokens = (
                    local_num_tokens * self.server_args.speculative_num_draft_tokens
                )
            local_logprob_tokens = local_num_tokens
        else:
            local_num_tokens = local_batch.extend_num_tokens
            # We should have at least 1 token for sample in every case.
            local_logprob_tokens = sum(
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        graph_capable = (
            local_batch is None or local_batch.forward_mode.is_decode_or_idle()
        )
        local_can_graph = 1 if graph_capable else 0

        # TODO(sang): Support cuda graph when idle batch is there.
        if spec_enabled and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            local_can_graph = 0

        local_has_extend = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        local_info = torch.tensor(
            [
                local_num_tokens,
                local_can_graph,
                local_logprob_tokens,
                local_has_extend,
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (self.server_args.dp_size, self.attn_tp_size, 4),
            dtype=torch.int64,
        )
        # The gathered values are read back through ``global_info`` below.
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=self.tp_cpu_group,
        )

        # One row per DP rank, taken from index 0 of the attn-TP axis.
        per_dp_rank = global_info[:, 0]
        global_num_tokens = per_dp_rank[:, 0].tolist()
        can_cuda_graph = min(per_dp_rank[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp_rank[:, 2].tolist()
        is_extend_in_batch = per_dp_rank[:, 3].tolist()

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1528
1529
1530
1531
1532

    def get_idle_batch(self):
        """Build a batch with no requests and prepare it for idle mode.

        Returns:
            A new ``ScheduleBatch`` created from this scheduler's pools, cache
            and configuration, on which ``prepare_for_idle()`` has been called.
        """
        no_reqs = []
        batch = ScheduleBatch.init_new(
            no_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned from the front; each request's grammar future is
        given 0.05 seconds to resolve, and scanning stops at the first one
        that times out. With more than one TP rank, the ranks agree on the
        maximum ready count and block on the futures that were not ready
        locally, so that every rank moves the same number of requests.
        """
        ready_count = 0
        for pending_req in self.grammar_queue:
            try:
                pending_req.grammar = pending_req.grammar.result(timeout=0.05)
            except futures.TimeoutError:
                break
            ready_count += 1

        if self.server_args.enable_dp_attention:
            tp_size, tp_group = self.attn_tp_size, self.attn_tp_cpu_group
        else:
            tp_size, tp_group = self.tp_size, self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            counter = torch.tensor(ready_count, dtype=torch.int32)
            torch.distributed.all_reduce(
                counter, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            agreed_count = counter.item()
            # Wait (without timeout) for the grammars another rank already has.
            for idx in range(ready_count, agreed_count):
                lagging_req = self.grammar_queue[idx]
                lagging_req.grammar = lagging_req.grammar.result()
            ready_count = agreed_count

        self._extend_requests_to_queue(self.grammar_queue[:ready_count])
        self.grammar_queue = self.grammar_queue[ready_count:]

1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        While a batch is current, ``forward_ct`` is polled every
        ``watchdog_timeout // 2`` seconds. If it has not changed for more than
        ``watchdog_timeout`` seconds, diagnostics are logged and ``SIGQUIT``
        is sent to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            now = time.time()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct != self.forward_ct:
                    # Progress was made since the last poll; restart the timer.
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = now
                elif now > self.watchdog_last_time + self.watchdog_timeout:
                    logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
                    break
            time.sleep(self.watchdog_timeout // 2)

        # Print batch size and memory pool info to check whether there are de-sync issues.
        logger.error(
            f"{self.cur_batch.batch_size()=}, "
            f"{self.cur_batch.reqs=}, "
            f"{self.token_to_kv_pool_allocator.available_size()=}, "
            f"{self.tree_cache.evictable_size()=}, "
        )
        pyspy_dump_schedulers()
        # Flush both streams, then wait for some time so that the parent
        # process can print the error.
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1605
1606
1607
    def flush_cache_wrapped(self, recv_req: FlushCacheReq):
        """Handle a ``FlushCacheReq`` by calling ``flush_cache``.

        ``recv_req`` is not inspected, and the success flag returned by
        ``flush_cache`` is discarded, so this method returns ``None``.
        """
        self.flush_cache()

1608
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when the waiting queue is empty and the
        running batch is empty. In that case the tree cache, the grammar
        backend (if any), the request/token pools (including the draft
        worker's pools when a speculative algorithm is configured) and the
        generation/speculation counters are reset, and the CUDA caching
        allocator is emptied.

        Returns:
            True if the cache was flushed, False if it was skipped because
            requests are still pending.
        """
        has_pending = len(self.waiting_queue) != 0 or not self.running_batch.is_empty()
        if has_pending:
            # Use the module logger, like every other message in this file.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Report a snapshot of scheduler state.

        The snapshot is a copy of ``global_server_args_dict`` extended with
        ``last_gen_throughput``; ``avg_spec_accept_length`` is added when a
        speculative algorithm is active and acceptances have been counted, and
        ``step_time_dict`` is added when ``RECORD_STEP_TIME`` is set.

        Args:
            recv_req: The request; it is not inspected.

        Returns:
            A ``GetInternalStateReqOutput`` wrapping the snapshot dict.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update the whitelisted entries of ``global_server_args_dict``.

        Only ``speculative_accept_threshold_single`` and
        ``speculative_accept_threshold_acc`` may be changed. If the request
        contains any other key, nothing is updated. On a successful update the
        cumulative speculative-acceptance counters are reset (after logging
        their average, when available).

        Args:
            recv_req: Request whose ``server_args`` maps argument names to
                their new values.

        Returns:
            A ``SetInternalStateReqOutput`` whose ``updated`` flag tells
            whether the update was applied, together with the current
            ``global_server_args_dict``.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }

        if_success = True
        for k in server_args_dict:
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")

        # Report the real outcome instead of unconditionally claiming success.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke the scheduler method named by an RPC request.

        The method ``recv_req.method`` is looked up on ``self`` and called
        with ``recv_req.parameters``. Any exception is caught, logged and
        reported in the output. ``barrier()`` is called whether or not the
        call succeeded.

        Args:
            recv_req: The RPC request carrying ``method`` and ``parameters``.

        Returns:
            An ``RpcReqOutput`` with the success flag and the error text
            (empty when there was no error).
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
            success = True
        except Exception as e:
            success = False
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(success, str(error) if error else "")

    def save_remote_model(self, params):
        """Save the model to a remote location.

        Args:
            params: Mapping with a ``"url"`` entry that is passed to the model
                runner's ``save_remote_model``.
        """
        remote_url = params["url"]

        # A TpModelWorkerClient wraps the worker in its ``worker`` attribute.
        tp_worker = self.tp_worker
        worker = (
            tp_worker.worker
            if isinstance(tp_worker, TpModelWorkerClient)
            else tp_worker
        )
        worker.model_runner.save_remote_model(remote_url)

    def save_sharded_model(self, params):
        """Save the model as shards through the model runner.

        Args:
            params: Mapping with ``"path"``, ``"pattern"`` and ``"max_size"``
                entries, forwarded as keyword arguments to the model runner's
                ``save_sharded_model``.
        """
        # A TpModelWorkerClient wraps the worker in its ``worker`` attribute.
        tp_worker = self.tp_worker
        worker = (
            tp_worker.worker
            if isinstance(tp_worker, TpModelWorkerClient)
            else tp_worker
        )
        runner = worker.model_runner
        runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

1725
1726
    def abort_request(self, recv_req: AbortReq):
        """Abort the first request whose id starts with ``recv_req.rid``.

        The waiting queue is searched first; a match there is removed from
        the queue. Otherwise the first unfinished match in the running batch
        is flagged with ``to_abort = True``. At most one request is affected
        per call.

        Args:
            recv_req: The abort request; ``rid`` is matched as a prefix.
        """
        rid_prefix = recv_req.rid

        # Delete requests in the waiting queue
        for idx, queued in enumerate(self.waiting_queue):
            if queued.rid.startswith(rid_prefix):
                # Returning right after the pop keeps the iteration safe.
                req = self.waiting_queue.pop(idx)
                logger.debug(f"Abort queued request. {req.rid=}")
                return

        # Delete requests in the running batch
        for req in self.running_batch.reqs:
            if not req.rid.startswith(rid_prefix):
                continue
            if req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
            return
1745

1746
1747
1748
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Placeholder that is not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

Chayenne's avatar
Chayenne committed
1749
1750
1751
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the cache is flushed, and the flush is asserted to have
        succeeded; on failure the worker's message is logged as an error.

        Returns:
            An ``UpdateWeightFromDiskReqOutput`` built from the worker's
            success flag and message, with ``0`` as its third argument.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
1758

1759
1760
1761
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Returns:
            An ``InitWeightsUpdateGroupReqOutput`` carrying the success flag
            and message reported by the TP worker.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1763
1764

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        On success the cache is flushed, and the flush is asserted to have
        succeeded; on failure the worker's message is logged as an error.

        Args:
            recv_req: The update request, forwarded to the TP worker.

        Returns:
            An ``UpdateWeightsFromDistributedReqOutput`` carrying the worker's
            success flag and message (not a bare ``(bool, str)`` tuple).
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1776

1777
1778
1779
1780
1781
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        On success the cache is flushed only when ``recv_req.flush_cache`` is
        set, and that flush is asserted to have succeeded; on failure the
        worker's message is logged as an error.

        Returns:
            An ``UpdateWeightsFromTensorReqOutput`` carrying the worker's
            success flag and message.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flushed = self.flush_cache()
            assert flushed, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
1788

1789
1790
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in a ``GetWeightsByNameReqOutput``."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
1792

1793
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Stash the model's buffers, pause the memory saver and flush the cache.

        A copy of the model's named buffers is stored in
        ``stashed_model_static_state`` before the memory saver is paused.
        The result of ``flush_cache`` is not checked.

        Returns:
            An empty ``ReleaseMemoryOccupationReqOutput``.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)

        adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
1803

1804
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver and restore the stashed model buffers.

        The buffers saved in ``stashed_model_static_state`` are copied back
        into the model, after which the stash attribute is deleted.

        Returns:
            An empty ``ResumeMemoryOccupationReqOutput``.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()

        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

    def profile(self, recv_req: ProfileReq):
        """Start or stop profiling according to ``recv_req.type``.

        ``START_PROFILE`` requests are forwarded to ``start_profile`` with the
        request's output directory, step count and activities; every other
        type is treated as a stop request.

        Returns:
            Whatever ``start_profile`` or ``stop_profile`` returns.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        return self.start_profile(
            recv_req.output_dir, recv_req.num_steps, recv_req.activities
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
    ) -> Optional[ProfileReqOutput]:
        """Start profiling with the requested activities.

        Args:
            output_dir: Directory for the traces. Defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` environment variable, or
                ``/tmp`` when that is unset.
            num_steps: When set, profiling is stopped automatically once
                ``forward_ct`` has advanced by this many steps.
            activities: Any of ``"CPU"``, ``"GPU"`` and ``"MEM"``. Defaults
                to ``["CPU", "GPU"]``. Names other than these three are
                ignored.

        Returns:
            A failed ``ProfileReqOutput`` if profiling is already running; a
            successful one when no ``num_steps`` was given; ``None`` when
            ``num_steps`` was given, because the reply is then sent by
            ``stop_profile`` once the target step is reached.
        """
        if self.torch_profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]

        self.torch_profiler_output_dir = output_dir
        self.torch_profiler_activities = activities
        logger.info(
            "Profiling starts. Traces will be saved to: %s",
            self.torch_profiler_output_dir,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=True,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
1869
1870

    def stop_profile(self) -> None:
        """Stop profiling and write the collected traces.

        Does nothing when no profiling is active. Otherwise the torch
        profiler trace (if a profiler was created) and the CUDA memory
        snapshot (if ``"MEM"`` was requested) are written into
        ``torch_profiler_output_dir``, and the profiler state is cleared.
        When a target forward count is set, a success ``ProfileReqOutput`` is
        sent to the tokenizer.
        """
        if self.torch_profiler_activities is None:
            return

        logger.info("Stop profiling...")
        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.torch_profiler_activities:
            # The output directory is stored by start_profile under
            # ``torch_profiler_output_dir``; use it here as well so that the
            # memory snapshot lands next to the trace.
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.torch_profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
1904

1905
1906
1907
1908
1909
1910
1911
1912
1913
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the expert distribution recording.

        Args:
            recv_req: One of the ``ExpertDistributionReq`` values
                ``START_RECORD``, ``STOP_RECORD`` or ``DUMP_RECORD``.

        Returns:
            An empty ``ExpertDistributionReqOutput``.

        Raises:
            ValueError: If ``recv_req`` equals none of the three values.
        """
        recorder_methods = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for request_kind, method_name in recorder_methods:
            if recv_req == request_kind:
                getattr(expert_distribution_recorder, method_name)()
                break
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
1915

1916
    def open_session(self, recv_req: OpenSessionReqInput):
        """Open a new session under the requested id.

        Args:
            recv_req: Request carrying ``session_id`` and
                ``capacity_of_str_len``.

        Returns:
            An ``OpenSessionReqOutput`` with the session id and a flag that
            is False when the id already exists or is ``None``, True when the
            session was created.
        """
        session_id = recv_req.session_id

        # handle error
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
1930
1931
1932
1933
1934
1935
1936
1937
1938

    def close_session(self, recv_req: CloseSessionReqInput):
        """Delete the session named by ``recv_req.session_id``.

        An unknown session id only produces a warning. Nothing is returned.
        """
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
        else:
            # handle error
            logger.warning(f"session id {session_id} does not exist, cannot delete.")

1939

1940
1941
1942
1943
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` has a request id starting with ``"HEALTH_CHECK"``.

    Objects without a ``rid`` attribute, or whose ``rid`` is not a string
    (for example ``None``), are reported as not being health-check requests
    instead of raising.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


1958
1959
1960
1961
1962
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler process.

    Configures the process (title, fault handler, logger, optional CPU
    affinity), creates a ``Scheduler``, reports readiness through
    ``pipe_writer`` and then runs the event loop that matches the
    scheduler's disaggregation mode. Any exception is logged and ``SIGQUIT``
    is sent to the parent process.

    Args:
        server_args: Server configuration.
        port_args: Port configuration passed to the scheduler.
        gpu_id: GPU used by this process.
        tp_rank: Tensor-parallel rank of this process.
        dp_rank: Data-parallel rank, or ``None``. When ``None``, the
            ``SGLANG_DP_RANK`` environment variable is used if present.
        pipe_writer: Pipe end on which the "ready" message is sent.
    """
    # Generate the prefix (computed before SGLANG_DP_RANK is consulted).
    prefix = f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
        ready_message = {
            "status": "ready",
            "max_total_num_tokens": scheduler.max_total_num_tokens,
            "max_req_input_len": scheduler.max_req_input_len,
        }
        pipe_writer.send(ready_message)

        mode: DisaggregationMode = scheduler.disaggregation_mode
        if mode == DisaggregationMode.NULL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif mode == DisaggregationMode.PREFILL:
            scheduler.event_loop_normal_disagg_prefill()
        elif mode == DisaggregationMode.DECODE:
            scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)