scheduler.py 89.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
39
40
41
42
43
44
45
46
47
48
49
50
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    ReqToMetadataIdxAllocator,
51
    TransferBackend,
Byron Hsu's avatar
Byron Hsu committed
52
)
53
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
54
55
56
57
58
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
59
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
60
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
61
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
62
63
from sglang.srt.managers.io_struct import (
    AbortReq,
64
    CloseSessionReqInput,
65
    ExpertDistributionReq,
66
    ExpertDistributionReqOutput,
67
68
    FlushCacheReqInput,
    FlushCacheReqOutput,
69
70
    GetInternalStateReq,
    GetInternalStateReqOutput,
71
72
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
73
    HealthCheckOutput,
74
75
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
76
77
    OpenSessionReqInput,
    OpenSessionReqOutput,
78
    ProfileReq,
79
80
    ProfileReqOutput,
    ProfileReqType,
81
82
83
84
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
85
86
    RpcReqInput,
    RpcReqOutput,
87
88
    SetInternalStateReq,
    SetInternalStateReqOutput,
89
90
    SlowDownReqInput,
    SlowDownReqOutput,
91
92
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
93
94
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
95
96
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
97
98
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
99
100
101
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
102
    MultimodalInputs,
103
104
    Req,
    ScheduleBatch,
105
    global_server_args_dict,
106
)
107
108
109
110
111
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
112
113
114
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
115
from sglang.srt.managers.session_controller import Session
116
from sglang.srt.managers.tp_worker import TpModelWorker
117
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
118
from sglang.srt.managers.utils import validate_input_length
119
from sglang.srt.mem_cache.chunk_cache import ChunkCache
120
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
121
from sglang.srt.mem_cache.radix_cache import RadixCache
122
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
123
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
124
from sglang.srt.reasoning_parser import ReasoningParser
125
from sglang.srt.server_args import PortArgs, ServerArgs
126
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
127
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
128
from sglang.srt.utils import (
129
    DynamicGradMode,
130
131
    broadcast_pyobj,
    configure_logger,
132
    crash_on_warnings,
Lianmin Zheng's avatar
Lianmin Zheng committed
133
    disable_request_logging,
134
    get_bool_env_var,
135
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
136
    kill_itself_when_parent_died,
137
    point_to_point_pyobj,
138
    pyspy_dump_schedulers,
139
    set_gpu_proc_affinity,
140
141
142
    set_random_seed,
    suppress_other_loggers,
)
143
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
144

145
146
expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes (env: SGLANG_TEST_RETRACT).
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# NOTE(review): presumably enables per-step timing collection (cf.
# `step_time_dict` in `init_metrics`) -- confirm at its use sites.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar timeout, overridable via SGLANG_GRAMMAR_TIMEOUT (default 300;
# presumably seconds -- confirm at its use sites).
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))


@dataclass
class GenerationBatchResult:
    """Result of running one generation batch through the model worker."""

    # Logits processor output for the batch, if any.
    logits_output: Optional[LogitsProcessorOutput]
    # Hidden-state proxy tensors exchanged between pipeline-parallel stages.
    # NOTE(review): presumably only populated on non-last PP ranks -- confirm.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    # Next token ids produced for the batch, if any.
    next_token_ids: Optional[List[int]]
    # Per-request extend (prefill) input lengths.
    extend_input_len_per_req: List[int]
    # Per-request start positions for returning input logprobs.
    extend_logprob_start_len_per_req: List[int]
    # Id of the batch this result belongs to.
    bid: int
    # Whether the forward pass could run with a CUDA graph.
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result of running one embedding (non-generation) batch."""

    # Embedding output tensor for the batch.
    embeddings: torch.Tensor
    # Id of the batch these embeddings belong to.
    bid: int


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
177
178
179
180
181
182
183
184
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up the scheduler, its model worker(s), caches and IPC channels.

        Args:
            server_args: Global server configuration.
            port_args: IPC endpoint names and the NCCL port used to connect
                this scheduler to the tokenizer/detokenizer/RPC sockets and
                to the model workers.
            gpu_id: GPU index passed to the model (and draft) workers.
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank forwarded to the workers. Note that
                ``self.dp_rank`` is NOT taken from this argument; it is
                recomputed by ``compute_dp_attention_world_info`` below.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only the first attention-TP rank
        # of the first pipeline stage owns real sockets; every other rank gets
        # no receivers and no-op senders.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        # Fill the micro-batch default only AFTER merging the worker's view:
        # if the worker dict is a distinct object still holding None, the
        # update above would otherwise overwrite the computed default.
        # (A no-op reordering when both names refer to the same dict.)
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.time()
        self.last_prefill_stats_tic = time.time()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A missing (None) or non-positive size
        # (-1 means disable) turns chunked prefill off.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is None or self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # new_token_ratio starts at init_new_token_ratio, is lowered by
        # new_token_ratio_decay toward min_new_token_ratio, and is reset to
        # the initial value whenever the event loop goes idle.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread. Every attribute the thread might touch
        # (NOTE(review): watchdog_thread is defined elsewhere; parent_process
        # is presumably among them) is assigned before the thread starts.
        self.watchdog_timeout = server_args.watchdog_timeout
        self.parent_process = psutil.Process().parent()
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()

        # Init request dispatcher: maps each incoming request type to its handler.
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def init_tokenizer(self):
        """Load the model config and, unless disabled, the tokenizer/processor.

        Always sets ``self.model_config``, ``self.is_generation``,
        ``self.tokenizer`` and ``self.processor``:

        * ``server_args.skip_tokenizer_init``: tokenizer and processor are None.
        * multimodal model: the processor is loaded and the tokenizer is taken
          from it.
        * otherwise: only a tokenizer is loaded and ``self.processor`` is None.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        if self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            # Keep the attribute defined for text-only models too, matching
            # the skip_tokenizer_init branch, so readers never hit
            # AttributeError.
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Cache selection:

        * ``ChunkCache`` when ``server_args.chunked_prefill_size`` is not None
          and the radix cache is disabled;
        * otherwise ``HiRadixCache`` when hierarchical caching is enabled;
        * otherwise ``RadixCache``, with ``disable=server_args.disable_radix_cache``.

        Also sets ``self.decode_mem_cache_buf_multiplier``. It is 1 without
        speculative decoding. Otherwise it is ``speculative_num_draft_tokens
        + speculative_eagle_topk * speculative_num_steps``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
            )

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

    def init_metrics(self):
        """Initialize throughput/length trackers, speculative-decoding counters
        and scheduler stats.

        When ``self.enable_metrics`` is set, also creates a
        ``SchedulerMetricsCollector`` labelled with the served model name and
        engine type ``"unified"``.
        """
        # The largest prefill length of a single request
        self._largest_prefill_len: int = 0
        # The largest context length (prefill + generation) of a single request
        self._largest_prefill_decode_len: int = 0
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        # Speculative decoding acceptance counters
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    "engine_type": engine_type,
                },
            )

    def init_disaggregation(self):
        """Create the prefill/decode disaggregation queues for this scheduler.

        * DECODE mode: builds the KV-transfer queue and the pre-allocation
          queue. Metadata buffers are sized at twice the req_to_token pool
          size.
        * PREFILL mode: builds the bootstrap queue and the in-flight list.
          Metadata buffers are sized at twice ``max_running_requests``.
        * Any other mode: only ``self.transfer_backend`` is set.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        def _make_metadata_buffers(buffer_size: int):
            """Return (slot allocator, aux dtype, metadata buffers) of `buffer_size` slots."""
            allocator = ReqToMetadataIdxAllocator(buffer_size)
            aux_dtype = torch.int32
            # A list of metadata buffers of shape (buffer_size, metadata_size).
            # Each row is padded to 16 int32 values (64 bytes) so the row
            # size works with RDMA.
            output_id_buffer = torch.zeros(
                (buffer_size, 16), dtype=aux_dtype, device="cpu"
            )
            return allocator, aux_dtype, [output_id_buffer]

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            req_to_metadata_buffer_idx_allocator, aux_dtype, metadata_buffers = (
                _make_metadata_buffers(self.req_to_token_pool.size * 2)
            )

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            req_to_metadata_buffer_idx_allocator, aux_dtype, metadata_buffers = (
                _make_metadata_buffers(self.max_running_requests * 2)
            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
                aux_dtype=aux_dtype,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration does the following synchronously:

        1. Receive and dispatch new requests.
        2. Pick the next batch.
        3. Run the batch and process its result.

        When there is nothing to run, the loop runs ``check_memory()`` and
        resets ``new_token_ratio`` to its initial value. Never returns.
        """
        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                result = self.run_batch(batch)
                self.process_batch_result(batch, result)
            else:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        Each iteration launches the current batch and queues ``(batch copy,
        result)``. It then processes the *previous* batch's result from
        ``self.result_queue``, so result processing for batch N-1 overlaps
        with batch N. Never returns.
        """
        self.result_queue = deque()

        while True:
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), result))

                if self.last_batch is None:
                    # Create a dummy first batch to start the pipeline for overlap schedule.
                    # It is now used for triggering the sampling_info_done event.
                    tmp_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(tmp_batch, None, batch.launch_done)

            if self.last_batch:
                # Process the results of the last batch
                tmp_batch, tmp_result = self.result_queue.popleft()
                tmp_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: we should use the currently launched batch's launch_done
                # event instead of the last batch's.
                self.process_batch_result(
                    tmp_batch, tmp_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop cycles over ``pp_size`` microbatch slots. For each slot it
        schedules and runs a batch, forwards sampled token ids / requests /
        hidden-state proxies to the neighbouring pipeline stage, and
        post-processes the *next* slot's microbatch with the token ids received
        from the pipeline.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # CUDA-graph flag recorded per microbatch slot when that slot runs, so the
        # deferred GenerationBatchResult of a microbatch reports its own flag
        # instead of the flag of whichever microbatch ran most recently.
        can_run_cuda_graphs = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    can_run_cuda_graphs[mb_id] = result.can_run_cuda_graph

                # send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=can_run_cuda_graphs[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # carry the outputs to the next stage
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    if pp_outputs:
                        # send the outputs from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                if not self.pp_group.is_last_rank:
                    # send out reqs to the next stage
                    dp_offset = self.dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.

        Only attention-TP rank 0 of each pipeline stage receives directly:
        pipeline stage 0 drains the tokenizer socket and then the RPC socket
        without blocking, while later stages take the request list forwarded
        point-to-point by the previous stage. The list is then broadcast inside
        the TP group; with DP attention enabled, work requests (generate /
        embedding) and control requests are broadcast over separate groups and
        concatenated work-first.
        """

        def _drain(socket, sink):
            # Append every message already queued on ``socket``; the first
            # non-blocking receive that raises ZMQError ends the drain.
            while True:
                try:
                    message = socket.recv_pyobj(zmq.NOBLOCK)
                except zmq.ZMQError:
                    return
                sink.append(message)

        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:
                recv_reqs = []
                _drain(self.recv_from_tokenizer, recv_reqs)
                _drain(self.recv_from_rpc, recv_reqs)
            else:
                own_rank = self.pp_rank * self.tp_size + self.dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    own_rank,
                    self.world_group.cpu_group,
                    own_rank - self.tp_size,
                    own_rank,
                )

        if self.server_args.enable_dp_attention:
            work_reqs = control_reqs = None
            if self.attn_tp_rank == 0:
                work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    bucket = work_reqs if isinstance(req, work_types) else control_reqs
                    bucket.append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and forward any reply it produces.

        A health-check generate request is dropped (and counted in
        ``self.return_health_check_ct``) whenever a chunked prefill is pending
        or the running batch is non-empty. Replies of type ``RpcReqOutput`` are
        sent back on the RPC socket when this rank has one; every other
        non-None reply goes to the tokenizer.
        """
        for incoming in recv_reqs:
            if is_health_check_generate_req(incoming):
                busy = self.chunked_req is not None or not self.running_batch.is_empty()
                if busy:
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(incoming)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        Requests that fail a check here (unknown session id, multimodal prompt
        too long after expansion, failed ``validate_input_length``,
        ``logprob_start_len`` out of range) are still enqueued. Before that they
        are marked with a ``FINISH_ABORT`` reason and/or ``max_new_tokens = 0``.
        Requests carrying a constrained-decoding spec whose grammar is not yet
        cached go to ``self.grammar_queue``; all others go through
        ``_add_request_to_queue``.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was supplied but is not present in self.sessions.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens (unset -> 1 << 30) so prompt + output stays within
        # self.max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None
                # here. A plain ``else`` (rather than a truthiness test) keeps
                # ``key`` bound even for a present-but-falsy value such as "".
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            req.queue_time_start = time.time()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Stamp ``req.queue_time_start`` and hand ``req`` to the queue for this mode.

        PREFILL disaggregation uses the bootstrap queue, DECODE disaggregation
        uses the pre-allocation queue, and every other mode uses
        ``self.waiting_queue``.
        """
        req.queue_time_start = time.time()
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            enqueue = self.disagg_prefill_bootstrap_queue.add
        elif mode == DisaggregationMode.DECODE:
            enqueue = self.disagg_decode_prealloc_queue.add
        else:
            enqueue = self.waiting_queue.append
        enqueue(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests at once to the queue for this mode.

        DECODE disaggregation uses the pre-allocation queue; every other mode
        (including PREFILL) uses ``self.waiting_queue``. ``is_retracted`` is
        accepted but not read by this method.
        """
        target = (
            self.disagg_decode_prealloc_queue
            if self.disaggregation_mode == DisaggregationMode.DECODE
            else self.waiting_queue
        )
        target.extend(reqs)

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        A multimodal prompt that is too long after expansion is enqueued with a
        ``FINISH_ABORT`` reason and ``max_new_tokens = 0``; a request failing
        ``validate_input_length`` is enqueued as-is. Every path goes through
        ``_add_request_to_queue``, so all requests from this method are routed
        the same way in every disaggregation mode.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper (it also stamps queue_time_start) instead
                # of appending to waiting_queue directly, so this path is routed
                # like every other path in this method.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a new prefill batch and, if enabled, update metrics.

        Args:
            adder: PrefillAdder that built the batch; its ``log_input_tokens`` and
                ``log_hit_tokens`` counters are reported.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of requests currently in the running batch.
        """
        gap_latency = time.time() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.time()
        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        free_or_evictable = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_or_evictable
        self._largest_prefill_len = max(
            self._largest_prefill_len, adder.log_input_tokens
        )

        num_new_seq = len(can_run_list)
        parts = [
            "Prefill batch. ",
            f"#new-seq: {num_new_seq}, ",
            f"#new-token: {adder.log_input_tokens}, ",
            f"#cached-token: {adder.log_hit_tokens}, ",
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, ",
            f"#running-req: {running_bs}, ",
        ]
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            parts += [
                f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, ",
                f"#queue-req: {len(self.waiting_queue)}, ",
                f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} ",
            ]
        else:
            parts.append(f"#queue-req: {len(self.waiting_queue)}")
        logger.info("".join(parts))

        if not self.enable_metrics:
            return

        hit_tokens = adder.log_hit_tokens
        cache_hit_rate = hit_tokens / (adder.log_input_tokens + hit_tokens)

        stats = self.stats
        stats.num_running_reqs = running_bs
        stats.num_used_tokens = num_used
        stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
        stats.num_queue_reqs = len(self.waiting_queue)
        stats.cache_hit_rate = cache_hit_rate

        total_queue_latency = sum(
            req.queue_time_end - req.queue_time_start for req in can_run_list
        )
        stats.avg_request_queue_latency = total_queue_latency / num_new_seq

        self.metrics_collector.log_stats(stats)

    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a one-line summary of decode progress and, if enabled, update metrics.

        Args:
            can_run_cuda_graph: CUDA-graph flag printed verbatim in the log line.
            running_batch: Batch to report on; falls back to ``self.running_batch``.
        """
        batch = running_batch or self.running_batch

        gap_latency = time.time() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.time()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # The counters are reset to 0 after every log below; guard the
            # division so a window without any counted speculative forward does
            # not raise ZeroDivisionError.
            if self.spec_num_total_forward_ct > 0:
                spec_accept_length = (
                    self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                )
            else:
                spec_accept_length = 0
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)

    def check_memory(self):
        """Verify that no KV-cache tokens or request slots have leaked.

        The event loops call this when the server is idle.

        Raises:
            ValueError: if free + evictable KV tokens do not equal
                ``max_total_num_tokens`` (minus ``protected_size`` when the
                hierarchical cache is enabled), or if ``req_to_token_pool`` has a
                different number of free slots than its size.

        When metrics are enabled on attention-TP rank 0 and more than 30 seconds
        have passed since ``metrics_collector.last_log_time``, idle-time stats
        are also logged.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            # The headline needs a trailing space (as in the KV message above);
            # without it the text ran together as "...detected!available_size=".
            msg = (
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.time() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)

    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch for the next forward pass.

        If the previous batch was an extend (prefill) batch, first filter it
        (excluding any still-chunked request) and fold it into
        ``self.running_batch``. Then prefer a new prefill batch from
        ``get_new_batch_prefill``; otherwise fall back to the updated running
        (decode) batch, or None if it is empty. With DP attention or SP layernorm
        enabled, the choice is finally passed through ``prepare_dp_attn_batch``.
        """
        # Merge the prefill batch into the running batch
        chunked_req_to_exclude = set()
        if self.chunked_req:
            # Move the chunked request out of the batch so that we can merge
            # only finished requests to running_batch.
            chunked_req_to_exclude.add(self.chunked_req)
            self.tree_cache.cache_unfinished_req(self.chunked_req)
            # chunked request keeps its rid but will get a new req_pool_idx
            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
        if self.last_batch and self.last_batch.forward_mode.is_extend():
            if self.last_batch.chunked_req is not None:
                # In the context of pipeline parallelism, after the last chunk the
                # current microbatch may still track an outdated chunked_req.
                # We need to discard it.
                chunked_req_to_exclude.add(self.last_batch.chunked_req)

            # Filter batch
            last_bs = self.last_batch.batch_size()
            self.last_batch.filter_batch(
                chunked_req_to_exclude=list(chunked_req_to_exclude)
            )
            if self.last_batch.batch_size() < last_bs:
                # The last batch shrank, so clear the running batch's "full" flag.
                self.running_batch.batch_is_full = False

            # Merge the new batch into the running batch
            if not self.last_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = self.last_batch
                else:
                    # Merge running_batch with prefill batch
                    self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
        else:
            # Run decode
            if not self.running_batch.is_empty():
                self.running_batch = self.update_running_batch(self.running_batch)
                ret = self.running_batch if not self.running_batch.is_empty() else None
            else:
                ret = None

        # Handle DP attention
        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret

    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted alongside ``running_bs``.

        The budget is ``max_micro_batch_size - running_bs``. With pipeline
        parallelism (``pp_size > 1``) it is additionally capped by the number of
        free ``req_to_token_pool`` slots.
        """
        headroom = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return headroom
        return min(headroom, self.req_to_token_pool.available_size())

    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Assemble the next prefill (extend) batch from the waiting queue.

        Returns ``None`` when nothing can be scheduled this round. Along the
        way it may set ``self.running_batch.batch_is_full``, removes the
        scheduled requests from ``self.waiting_queue``, updates
        ``self.chunked_req`` and, in mixed-chunk mode, folds the running decode
        requests into the returned batch and resets ``self.running_batch``.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {r.lora_path for r in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.time()

        # Build the membership set once so the filter is linear in the queue
        # length instead of re-creating a set for every waiting request.
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1474
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Finished requests are filtered out first; an emptied batch is returned
        immediately with ``batch_is_full`` cleared. If ``check_decode_mem``
        reports insufficient memory (or the ``TEST_RETRACT`` debug switch is on
        and the batch has more than 10 requests), requests are retracted back
        to the queue and ``self.new_token_ratio`` takes the value returned by
        ``retract_decode``; otherwise the ratio decays by
        ``new_token_ratio_decay``, floored at ``min_new_token_ratio``.
        Always returns the same ``batch`` object.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        # Check if decode out of memory
        mem_ok = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if not mem_ok or (TEST_RETRACT and batch.batch_size() > 10):
            previous_ratio = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {previous_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        # Any shrinkage frees a slot, so the batch can no longer be full.
        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1510

1511
1512
1513
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Increments ``self.forward_ct``, stops the profiler once
        ``profiler_target_forward_ct`` is reached and honours the optional
        ``forward_sleep_time`` delay before running the forward pass:

        * embedding / reward models return an ``EmbeddingBatchResult``;
        * generation models return a ``GenerationBatchResult`` produced by the
          TP worker, or by the draft worker when a speculative algorithm is set
          (which also updates the speculative and generated-token counters).

        On the last pipeline rank the result carries logits and next token ids;
        on other ranks it carries the pipeline hidden-state proxy tensors.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.stop_profile()

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # Embedding or reward model
            worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(worker_batch)
            return EmbeddingBatchResult(embeddings=embeddings, bid=worker_batch.bid)

        # Run forward
        if self.spec_algorithm.is_none():
            worker_batch = batch.get_model_worker_batch()
            is_last = self.pp_group.is_last_rank
            forward_out = self.tp_worker.forward_batch_generation(worker_batch)
            if is_last:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = worker_batch.bid
        else:
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = self.draft_worker.forward_batch_speculative_generation(batch)
            self.spec_num_total_accepted_tokens += (
                num_accepted_tokens + batch.batch_size()
            )
            self.spec_num_total_forward_ct += batch.batch_size()
            self.num_generated_tokens += num_accepted_tokens

        on_last_rank = self.pp_group.is_last_rank
        if on_last_rank:
            batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]
        else:
            extend_input_len_per_req = None
            extend_logprob_start_len_per_req = None

        # Only the names bound for this rank's role are referenced below.
        if on_last_rank:
            logits_for_result = logits_output
            proxy_for_result = None
            token_ids_for_result = next_token_ids
        else:
            logits_for_result = None
            proxy_for_result = pp_hidden_states_proxy_tensors
            token_ids_for_result = None

        return GenerationBatchResult(
            logits_output=logits_for_result,
            pp_hidden_states_proxy_tensors=proxy_for_result,
            next_token_ids=token_ids_for_result,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1590

1591
1592
1593
1594
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Route a forward result to the handler matching the batch's mode.

        Decode and extend results go to their dedicated processors. Idle batches
        under the overlap scheduler resolve the pending worker result and, when
        a next-batch sampling info exists, finalize its regex vocab mask and
        signal ``sampling_info_done``. Dummy-first batches always do the latter.
        Afterwards, one pending health-check reply is sent if requested.
        """

        def _finish_next_sampling_info():
            # Mask must be computed and the stream synced before the waiting
            # consumer is released via the event.
            info = batch.next_batch_sampling_info
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
            info.sampling_info_done.set()

        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                if batch.next_batch_sampling_info:
                    _finish_next_sampling_info()
        elif mode.is_dummy_first():
            _finish_next_sampling_info()

        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
            # However, one minor issue is that this code path does not check the status of detokenizer manager.
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1620
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Synchronize ``local_batch`` metadata across DP-attention ranks.

        Thin wrapper: forwards this scheduler's configuration to
        ``prepare_dp_attn_batch_raw`` and returns its
        ``(batch, any_extend_in_batch)`` tuple unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Exchange per-rank batch info over ``tp_cpu_group`` for DP attention.

        Each rank contributes four int64 values: its token count, whether it can
        use CUDA graph (0/1), its token count for logprob computation, and
        whether it runs an extend batch. They are all-gathered into a
        ``(dp_size, attn_tp_size, 4)`` tensor and the entries at attn-TP index 0
        are read back. A rank with no batch builds an idle batch via
        ``get_idle_batch`` when any rank has tokens. The global lists are stored
        on the (possibly new) local batch, and the minimum CUDA-graph flag is
        stored too unless ``disable_cuda_graph``.

        Returns:
            ``(local_batch_or_None, any_rank_is_extend)``.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens *= speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                # We should have at least 1 token for sample in every case.
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        decode_like = (
            local_batch is None or local_batch.forward_mode.is_decode_or_idle()
        )
        can_cuda_graph = 1 if decode_like else 0
        # TODO(sang): Support CUDA graph when idle batch is there.
        if not spec_algorithm.is_none() and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        leaders = global_info[:, 0, :]
        global_num_tokens = leaders[:, 0].tolist()
        can_cuda_graph = min(leaders[:, 1].tolist())
        global_num_tokens_for_logprob = leaders[:, 2].tolist()
        is_extend_in_batch = leaders[:, 3].tolist()

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens
            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
            # Check forward mode for CUDA graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1712
1713
1714
1715
1716

    def get_idle_batch(self):
        """Build a request-less ScheduleBatch and prepare it for an idle pass."""
        empty_batch = ScheduleBatch.init_new(
            [],
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1727
1728
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        Grammar futures are polled in queue order (0.03 s each) until the first
        one that is not ready. If that request has waited more than
        ``GRAMMAR_TIMEOUT`` seconds (counted in polls), it is aborted with a
        ``FINISH_ABORT`` reason. With tensor parallelism (``tp_size > 1``), the
        ranks all-reduce their counts so every rank moves and aborts the same
        requests. Ready and aborted requests are moved to the waiting queue.
        """
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            try:
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
                num_ready_reqs += 1
            except futures._base.TimeoutError:
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_timeout_reqs = 1
                break

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        # "Frontier" = index just past the last request this rank has settled:
        # the ready prefix plus, if it timed out, the request right after it.
        # Reducing (ready, frontier) with MAX gives frontier_max > ready_max only
        # when some rank with the largest ready prefix timed out on the next
        # request. Stale timeout flags from ranks that lag behind are ignored.
        frontier = num_ready_reqs + num_timeout_reqs
        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, frontier], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, frontier_max = tensor.tolist()

            # Block on grammars that another rank already has ready.
            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
        else:
            # Single rank: act on the local decision directly (previously the
            # timeout was computed but never applied here).
            num_ready_reqs_max, frontier_max = num_ready_reqs, frontier

        # At most one request: the one right after the agreed ready prefix,
        # whose grammar is still an unresolved future on every rank.
        for i in range(num_ready_reqs_max, frontier_max):
            req = self.grammar_queue[i]
            req.grammar.cancel()
            req.grammar = None
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            logger.error(error_msg)
            req.finished_reason = FINISH_ABORT(
                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
            )

        num_settled = frontier_max
        self._extend_requests_to_queue(self.grammar_queue[:num_settled])
        self.grammar_queue = self.grammar_queue[num_settled:]

1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Every ``watchdog_timeout / 2`` seconds it checks whether a batch is in
        flight (``self.cur_batch`` is set) while ``self.forward_ct`` has not
        advanced for more than ``watchdog_timeout`` seconds. If so it logs
        diagnostics (unless request logging is disabled), dumps the scheduler
        stacks, waits 5 s and sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.time()

        while True:
            current = time.time()
            # Snapshot once: the scheduler loop may reset cur_batch to None at
            # any time, and the diagnostics below must not crash on that
            # (a crash here would skip the SIGQUIT entirely).
            stuck_batch = self.cur_batch
            if stuck_batch is None:
                # Idle: restart the clock so time spent without a batch is not
                # counted against the next batch.
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
            elif self.watchdog_last_forward_ct == self.forward_ct:
                if current > self.watchdog_last_time + self.watchdog_timeout:
                    break
            else:
                self.watchdog_last_forward_ct = self.forward_ct
                self.watchdog_last_time = current
            # True division: floor division turned any timeout below 2 s into
            # a zero-length sleep, i.e. a busy loop.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            logger.error(
                f"self.cur_batch.batch_size()={stuck_batch.batch_size()!r}, "
                f"self.cur_batch.reqs={stuck_batch.reqs!r}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1813
1814
1815
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush request: run ``flush_cache`` and report its outcome."""
        return FlushCacheReqOutput(success=self.flush_cache())
1816

1817
    def flush_cache(self):
        """Flush the memory pool and cache.

        Only performed when the scheduler has no pending work: the waiting queue
        and running batch are empty and, with pipeline parallelism, every
        micro-batch in ``running_mbs`` is empty too. Then the tree cache,
        grammar backend cache, req-to-token and KV pools (and the draft model's
        pools under speculative decoding) are cleared, the throughput and
        speculative-decoding counters are zeroed, and
        ``torch.cuda.empty_cache()`` is called.

        Returns:
            True if the cache was flushed, False if requests were pending.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Use the module logger (like the rest of this file), not the root
            # logger via logging.warning.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Report a copy of the global server args plus live scheduler metrics.

        Adds ``last_gen_throughput``. It also adds ``avg_spec_accept_length``
        when a speculative algorithm is active and at least one acceptance was
        recorded, and ``step_time_dict`` when ``RECORD_STEP_TIME`` is set.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        has_spec_samples = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_samples:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )

        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        return GetInternalStateReqOutput(internal_state=state)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Only ``max_micro_batch_size`` and the two speculative accept thresholds
        may change. ``max_micro_batch_size`` must lie in
        ``[1, max_running_requests // pp_size]``. The update is all-or-nothing:
        the first rejected key aborts it. On success the speculative
        accept-length statistics are logged (if any) and reset, and the new
        values are written into ``global_server_args_dict``.

        Returns:
            SetInternalStateReqOutput. ``updated`` reports whether the values
            were applied; ``server_args`` is the current global args dict.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            if k == "max_micro_batch_size":
                limit = self.max_running_requests // self.pp_size
                if v > limit or v < 1:
                    logger.warning(
                        f"Updating {k} to {v} is rejected because it is out of the valid range [1, {limit}]."
                    )
                    if_success = False
                    break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            global_server_args_dict.update(server_args_dict)
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")

        # Previously always True, which reported rejected updates as applied.
        return SetInternalStateReqOutput(
            updated=if_success,
            server_args=global_server_args_dict,
        )

1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call the scheduler attribute named ``recv_req.method`` with its parameters.

        The attribute is looked up with ``getattr`` and called with
        ``recv_req.parameters`` as the only argument. An ``Exception`` raised by
        the lookup or the call is logged and reported in the reply instead of
        propagating. A ``torch.distributed`` barrier then runs before the reply
        is built.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None  # avoid shadowing the `exec` builtin
        try:
            handler = getattr(self, recv_req.method)
            handler(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        return RpcReqOutput(error is None, "" if not error else str(error))

    def save_remote_model(self, params):
        """Save the model to ``params["url"]`` through the TP worker's model runner."""
        target_url = params["url"]
        runner = self.tp_worker.worker.model_runner
        runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save sharded weights through the TP worker's model runner.

        ``params`` must provide ``path``, ``pattern`` and ``max_size``. They are
        forwarded as keyword arguments; any other keys are ignored.
        """
        runner = self.tp_worker.worker.model_runner
        forwarded = {key: params[key] for key in ("path", "pattern", "max_size")}
        runner.save_sharded_model(**forwarded)

1940
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests in the waiting queue are removed, and an ``AbortReq``
        is sent to the tokenizer for each. Matching, unfinished requests in the
        running batch, and in ``cur_batch`` when it is a different batch, are
        only flagged with ``to_abort = True``.
        """
        # TODO(lmzheng): abort the requests in the grammar queue.

        # Delete requests in the waiting queue. Indices are popped from the
        # back so the earlier ones stay valid.
        matching = [
            idx
            for idx, req in enumerate(self.waiting_queue)
            if req.rid.startswith(recv_req.rid)
        ]
        for idx in reversed(matching):
            req = self.waiting_queue.pop(idx)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Delete requests in the running batch
        candidates = self.running_batch.reqs
        if self.cur_batch is not None and self.cur_batch is not self.running_batch:
            candidates = candidates + self.cur_batch.reqs

        for req in candidates:
            if req.rid.startswith(recv_req.rid) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
1965

1966
1967
1968
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Hook for pausing the engine; this scheduler does not implement it."""
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
1969
1970
1971
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates to the TP worker. On success the cache is flushed (and the
        flush is asserted to succeed); on failure the worker's message is
        logged as an error. The worker's ``(success, message)`` pair is always
        returned wrapped in ``UpdateWeightFromDiskReqOutput``.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
1978

1979
1980
1981
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group via the TP worker."""
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
1983
1984

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> "UpdateWeightsFromDistributedReqOutput":
        """Update the online model parameters from a distributed source.

        Delegates to the TP worker. On success the cache is flushed and the
        flush is asserted to succeed; on failure the worker's message is
        logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput`` wrapping the worker's
            ``(success, message)`` pair. The previous ``Tuple[bool, str]``
            annotation did not match what the method returns.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
1996

1997
1998
1999
2000
2001
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from tensors.

        Delegates to the TP worker. The cache is flushed only when the update
        succeeded AND ``recv_req.flush_cache`` is set; a failed update logs the
        worker's message instead.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
2008

2009
2010
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in the output type."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2012

2013
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release GPU memory held by the model through the memory saver adapter.

        The model's buffers are snapshotted first (so they can be restored by
        ``resume_memory_occupation``), then the adapter is paused and the cache
        is flushed. The return value of ``flush_cache`` is not checked here.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        # Snapshot must precede pause(): the copies need live memory to read from.
        self.stashed_model_static_state = _export_static_state(model)

        adapter.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2023

2024
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Re-acquire memory and restore the buffers stashed at release time.

        Expects ``release_memory_occupation`` to have run first; it reads and
        then deletes ``self.stashed_model_static_state``.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()

        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        # Drop the host-side copy once it has been written back.
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

2033
2034
2035
2036
2037
2038
2039
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store ``recv_req.forward_sleep_time`` on the scheduler.

        A non-positive value is normalized to ``None``. The value is presumably
        consumed by the forward loop elsewhere in this class; verify there.
        """
        sleep_time = recv_req.forward_sleep_time
        self.forward_sleep_time = (
            None if sleep_time is not None and sleep_time <= 0 else sleep_time
        )
        return SlowDownReqOutput()

2040
    def profile(self, recv_req: ProfileReq):
        """Dispatch a profiling request.

        ``START_PROFILE`` forwards the request's options to ``start_profile``;
        any other request type stops the current session via ``stop_profile``.
        The chosen method's return value is passed through.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        return self.start_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_id,
        )

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional["ProfileReqOutput"]:
        """Start a profiling session.

        Args:
            output_dir: Directory for output files; defaults to the
                ``SGLANG_TORCH_PROFILER_DIR`` env var, or ``/tmp``.
            num_steps: If truthy, a target forward count is recorded and the
                caller is notified later (``stop_profile`` sends the reply).
            activities: Any of ``"CPU"``, ``"GPU"`` (torch profiler), ``"MEM"``
                (CUDA memory history) and ``"CUDA_PROFILER"`` (cudaProfilerStart);
                defaults to ``["CPU", "GPU"]``. Unknown names are ignored.
            with_stack: Forwarded to ``torch.profiler.profile``; default True.
            record_shapes: Forwarded to ``torch.profiler.profile``; default False.
            profile_id: Prefix for output file names. When None, a timestamp
                string is used, because ``stop_profile`` concatenates this id
                into file names and would fail on None.

        Returns:
            A failed ``ProfileReqOutput`` if a session is already active; a
            successful one when ``num_steps`` is falsy; otherwise ``None``
            (the success reply is sent later from ``stop_profile``).
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
        else:
            # No torch activity requested: make sure stop_profile never stops
            # or exports a stale profiler handle.
            self.torch_profiler = None

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
            return None

        self.profiler_target_forward_ct = None
        return ProfileReqOutput(success=True, message="Succeeded")
2110
2111

    def stop_profile(self) -> None:
        """Stop the active profiling session and write its outputs.

        Does nothing when no session is active. Otherwise the following steps
        run in order:

        - The torch profiler (if any) is stopped and its Chrome trace exported.
        - With ``"MEM"``, a CUDA memory snapshot is dumped and recording is
          disabled.
        - With ``"CUDA_PROFILER"``, ``cudaProfilerStop`` is called.
        - The profiler state is reset.

        If the session was started with a step budget, a success
        ``ProfileReqOutput`` is sent to the tokenizer. The budget is then
        cleared so the step-limit check stops firing.
        """
        if self.profiler_activities is None:
            return

        logger.info("Stop profiling...")
        output_dir = self.torch_profiler_output_dir
        # Guard against a missing id: string concatenation below needs a str.
        profile_id = (
            self.profiler_id if self.profiler_id is not None else str(time.time())
        )
        # The target directory may not exist yet (e.g. a user-supplied path);
        # the exports below would fail without it.
        os.makedirs(output_dir, exist_ok=True)

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    output_dir,
                    profile_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                output_dir,
                profile_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        if self.profiler_target_forward_ct:
            self.send_to_tokenizer.send_pyobj(
                ProfileReqOutput(success=True, message="Succeeded.")
            )
            # The step budget has been consumed and reported; clear it so the
            # step-count check does not keep re-entering this method.
            self.profiler_target_forward_ct = None
2148

2149
2150
2151
2152
2153
2154
2155
2156
2157
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop or dump the expert-distribution recording.

        Raises:
            ValueError: if ``recv_req`` is none of the three known values.
        """
        # Compared in order with ==, exactly like an if/elif chain.
        actions = (
            (ExpertDistributionReq.START_RECORD, "start_record"),
            (ExpertDistributionReq.STOP_RECORD, "stop_record"),
            (ExpertDistributionReq.DUMP_RECORD, "dump_record"),
        )
        for kind, method_name in actions:
            if recv_req == kind:
                getattr(expert_distribution_recorder, method_name)()
                break
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        return ExpertDistributionReqOutput()
2159

2160
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new ``Session`` keyed by ``recv_req.session_id``.

        Fails (returns an output with ``False``) when the id is already in use
        or is None; otherwise registers the session and returns ``True``.
        """
        session_id = recv_req.session_id

        # Guard clauses for the error cases.
        if session_id in self.sessions:
            logger.warning(f"session id {session_id} already exist, cannot open.")
            return OpenSessionReqOutput(session_id, False)
        if session_id is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(session_id, False)

        self.sessions[session_id] = Session(recv_req.capacity_of_str_len, session_id)
        return OpenSessionReqOutput(session_id, True)
2174
2175
2176
2177
2178
2179
2180
2181
2182

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session ``recv_req.session_id``; warn if it is unknown."""
        session_id = recv_req.session_id
        if session_id in self.sessions:
            del self.sessions[session_id]
        else:
            logger.warning(f"session id {session_id} does not exist, cannot delete.")

2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
    def get_print_prefix(self):
        """Return a label such as ``" DP0 TP1 PP2"`` identifying this rank.

        Each part is included only when relevant: DP when a dp_rank is set,
        TP when tp_size > 1, PP when pp_size > 1. Returns "" otherwise.
        """
        parts = []
        if self.dp_rank is not None:
            parts.append(f" DP{self.dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2193

2194
2195
2196
2197
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` carries a string rid starting with "HEALTH_CHECK".

    Objects without a ``rid`` attribute, or whose rid is not a string (e.g.
    None or a list), are not health checks. Previously such rids raised
    AttributeError on ``.startswith``.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2212
2213
2214
2215
2216
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Performs these steps in order:

    - Configures the process: parent-death handling, process title,
      faulthandler, logger and optional CPU affinity.
    - Constructs a ``Scheduler``.
    - Sends a readiness dict (including ``max_total_num_tokens`` and
      ``max_req_input_len``) through ``pipe_writer``.
    - Runs the event loop selected by the disaggregation mode, the pipeline
      parallel size and the overlap setting.

    Any exception raised while building or running the scheduler is logged
    and ``SIGQUIT`` is sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exists, use it as dp_rank.
    # This must happen before the prefix is built; otherwise the process title
    # and log prefix would omit the DP rank the scheduler actually runs with.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        # Named exc_traceback to avoid shadowing the stdlib `traceback` module.
        exc_traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {exc_traceback}")
        parent_process.send_signal(signal.SIGQUIT)