scheduler.py 91.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from types import SimpleNamespace
28
from typing import Dict, List, Optional, Tuple, Union
29

30
import psutil
31
import setproctitle
32
import torch
33
import zmq
34
from torch.distributed import barrier
35

36
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
37
from sglang.srt.configs.model_config import ModelConfig
38
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
Byron Hsu's avatar
Byron Hsu committed
39
40
41
42
43
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
44
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
45
46
47
48
49
50
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
51
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
52
    ReqToMetadataIdxAllocator,
53
    TransferBackend,
54
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
55
)
56
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
57
58
59
60
61
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
62
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
63
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
64
65
66
67
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
68
69
from sglang.srt.managers.io_struct import (
    AbortReq,
70
    CloseSessionReqInput,
71
    ExpertDistributionReq,
72
    ExpertDistributionReqOutput,
73
74
    FlushCacheReqInput,
    FlushCacheReqOutput,
75
76
    GetInternalStateReq,
    GetInternalStateReqOutput,
77
78
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
79
    HealthCheckOutput,
80
81
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
82
83
    OpenSessionReqInput,
    OpenSessionReqOutput,
84
    ProfileReq,
85
86
    ProfileReqOutput,
    ProfileReqType,
87
88
89
90
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
91
92
    RpcReqInput,
    RpcReqOutput,
93
94
    SetInternalStateReq,
    SetInternalStateReqOutput,
95
96
    SlowDownReqInput,
    SlowDownReqOutput,
97
98
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
99
100
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
101
102
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
103
104
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
105
)
106
from sglang.srt.managers.mm_utils import init_embedding_cache
107
108
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
109
    MultimodalInputs,
110
111
    Req,
    ScheduleBatch,
112
    global_server_args_dict,
113
)
114
115
116
117
118
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
119
120
121
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
122
from sglang.srt.managers.session_controller import Session
123
from sglang.srt.managers.tp_worker import TpModelWorker
124
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
125
from sglang.srt.managers.utils import validate_input_length
126
from sglang.srt.mem_cache.chunk_cache import ChunkCache
127
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
128
from sglang.srt.mem_cache.radix_cache import RadixCache
129
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
130
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
131
from sglang.srt.reasoning_parser import ReasoningParser
132
from sglang.srt.server_args import PortArgs, ServerArgs
133
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
134
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
135
from sglang.srt.utils import (
136
    DynamicGradMode,
137
138
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
139
    disable_request_logging,
140
    get_bool_env_var,
141
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
142
    kill_itself_when_parent_died,
143
    point_to_point_pyobj,
144
    pyspy_dump_schedulers,
145
    set_gpu_proc_affinity,
146
147
148
    set_random_seed,
    suppress_other_loggers,
)
149
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
150
151
152

logger = logging.getLogger(__name__)

# Environment-controlled knobs, read once when this module is imported.
# SGLANG_TEST_RETRACT: test retract decode for debugging purposes.
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# SGLANG_RECORD_STEP_TIME: boolean switch for recording step times.
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# SGLANG_GRAMMAR_TIMEOUT: grammar timeout, default 300 (units presumably
# seconds — confirm at the use site).
GRAMMAR_TIMEOUT = float(os.getenv("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass
class GenerationBatchResult:
    """Result record for one forward pass over a generation batch.

    NOTE(review): the producer and consumers of this record are outside this
    view; the field notes below are inferred from names — confirm.
    """

    logits_output: Optional[LogitsProcessorOutput]
    # Pipeline-parallel proxy tensors; presumably populated only on
    # non-final pipeline stages — confirm.
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    next_token_ids: Optional[List[int]]
    # Per-request values, presumably aligned with the batch's request order.
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Presumably the batch id.
    bid: int
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result record for one forward pass over an embedding batch."""

    # Embedding output tensor (shape is not established here — confirm).
    embeddings: torch.Tensor
    # Presumably the batch id.
    bid: int


Byron Hsu's avatar
Byron Hsu committed
176
177
178
179
180
class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
181
182
183
184
185
186
187
188
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Build the scheduler: IPC sockets, workers, memory/cache, and state.

        Initialization order matters: the tokenizer/model config is loaded
        before the workers, the workers before the memory pool and cache, and
        the cache before the schedule policy and disaggregation queues.

        Args:
            server_args: Server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: Device index passed to the TP worker (and draft worker).
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or None.

        Raises:
            AssertionError: if ``server_args.schedule_conservativeness`` < 0.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size

        # Distributed rank info
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication. Only the rank with pp_rank == 0 and
        # attn_tp_rank == 0 opens ZMQ sockets; every other rank gets no-op
        # senders and no receivers.
        context = zmq.Context(2)
        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
        else:
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer (also sets self.model_config and self.is_generation)
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (e.g. -1) disables it; a
        # missing (None) size is treated as already disabled instead of raising
        # TypeError on the comparison.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        # Linear decay from init ratio to min ratio over the configured steps.
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread (daemon, so it does not block process exit)
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profiler_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None

        self.forward_sleep_time = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()
    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Always sets ``self.model_config`` and ``self.is_generation``. With
        ``skip_tokenizer_init`` both ``self.tokenizer`` and ``self.processor``
        are None. Multimodal models load a processor and take the tokenizer
        from it; other models load only a tokenizer.
        """
        args = self.server_args

        self.model_config = ModelConfig.from_server_args(args)
        self.is_generation = self.model_config.is_generation

        if args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
            return

        # Options shared by the tokenizer and processor loaders.
        common_kwargs = dict(
            tokenizer_mode=args.tokenizer_mode,
            trust_remote_code=args.trust_remote_code,
            revision=args.revision,
        )

        if not self.model_config.is_multimodal:
            self.tokenizer = get_tokenizer(args.tokenizer_path, **common_kwargs)
            return

        self.processor = get_processor(
            args.tokenizer_path,
            use_fast=not args.disable_fast_image_processor,
            **common_kwargs,
        )
        self.tokenizer = get_tokenizer_from_processor(self.processor)

    def init_memory_pool_and_cache(self):
        """Fetch memory pools from the TP worker and build the prefix cache.

        Cache selection:
          * ChunkCache when ``chunked_prefill_size`` is not None and the radix
            cache is disabled;
          * otherwise HiRadixCache when hierarchical caching is enabled;
          * otherwise RadixCache (which receives ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, else ``num_draft_tokens + eagle_topk * num_steps``.
        """
        args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )
        # Pool arguments common to every cache implementation.
        pools = dict(
            req_to_token_pool=self.req_to_token_pool,
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
        )

        if args.chunked_prefill_size is not None and args.disable_radix_cache:
            self.tree_cache = ChunkCache(**pools, page_size=self.page_size)
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                **pools,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=args.hicache_ratio,
                hicache_size=args.hicache_size,
                hicache_write_policy=args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                **pools,
                page_size=self.page_size,
                disable=args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        if self.spec_algorithm.is_none():
            self.decode_mem_cache_buf_multiplier = 1
        else:
            self.decode_mem_cache_buf_multiplier = (
                args.speculative_num_draft_tokens
                + args.speculative_eagle_topk * args.speculative_num_steps
            )

    def init_metrics(self):
        """Reset throughput and speculative-decoding counters.

        ``self.metrics_collector`` is only created when metrics are enabled;
        otherwise the attribute is not set here.
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        # batch size -> list of step times
        self.step_time_dict = defaultdict(list)
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()

        if not self.enable_metrics:
            return

        collector_labels = {
            "model_name": self.server_args.served_model_name,
            "engine_type": "unified",
        }
        self.metrics_collector = SchedulerMetricsCollector(labels=collector_labels)
    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV cache events are enabled.

        Args:
            kv_events_config: Configuration handed to
                ``EventPublisherFactory.create``.

        When events are disabled, ``self.kv_event_publisher`` is set to None so
        the attribute always exists and cannot raise AttributeError on access.
        """
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(kv_events_config)
        else:
            self.kv_event_publisher = None
    def init_disaggregation(self):
        """Set up state for prefill/decode disaggregation.

        Always sets ``self.transfer_backend``. In DECODE mode it creates the
        metadata buffers, the decode transfer queue and the pre-allocation
        queue. In PREFILL mode it creates the metadata buffers, the bootstrap
        queue and the in-flight list. Any other mode needs nothing further.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        is_decode = self.disaggregation_mode == DisaggregationMode.DECODE
        if is_decode:
            # *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
        else:
            return

        # Setup shared by both modes.
        req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(buffer_size)
        self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
        draft_token_to_kv_pool = (
            None
            if self.draft_worker is None
            else self.draft_worker.model_runner.token_to_kv_pool
        )

        if is_decode:
            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0
        else:
            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                transfer_backend=self.transfer_backend,
                scheduler=self,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []
    @DynamicGradMode()
    def event_loop_normal(self):
        """A normal scheduler loop.

        Each iteration receives and processes incoming requests, picks the
        next batch, and either runs it synchronously and processes its result,
        or, when there is nothing to run, checks memory and resets
        ``new_token_ratio``. Never returns.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            next_batch = self.get_next_batch_to_run()
            self.cur_batch = next_batch

            if not next_batch:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
            else:
                batch_result = self.run_batch(next_batch)
                self.process_batch_result(next_batch, batch_result)

            self.last_batch = next_batch
    @DynamicGradMode()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation.

        The batch launched in an iteration is queued together with its result
        handle. The result of the batch launched in the previous iteration is
        processed only after the current batch has been launched. Never returns.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Very first launch: feed a dummy batch through result
                    # processing to start the overlap pipeline; it is used to
                    # trigger the sampling_info_done event.
                    dummy_first_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(
                        dummy_first_batch, None, batch.launch_done
                    )

            if self.last_batch:
                # Process the results of the batch launched last iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                prev_batch.next_batch_sampling_info = (
                    self.tp_worker.cur_sampling_info if batch else None
                )
                # NOTE: pass the launch_done event of the batch launched in this
                # iteration (if any), not the one belonging to prev_batch.
                self.process_batch_result(
                    prev_batch, prev_result, batch.launch_done if batch else None
                )
            elif batch is None:
                # When the server is idle, do self-check and re-init some states
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

            self.last_batch = batch
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        The loop cycles over ``self.pp_size`` microbatch slots. For each slot it
        receives new requests, schedules and runs a batch, then receives the
        sampled token ids for the *next* slot and post-processes that slot's
        batch. Non-last ranks additionally forward the received requests, the
        previous round's outputs and their hidden-state proxy tensors to the
        next pipeline stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # Per-microbatch CUDA-graph flag from that microbatch's own forward pass.
        # Tracked like `bids` so the result built for `next_mb_id` does not read
        # the (possibly stale) flag of the microbatch that just ran.
        cuda_graph_flags = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    cuda_graph_flags[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=cuda_graph_flags[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio

802
803
    def recv_requests(self) -> List[Req]:
        """Gather newly arrived requests and share them with the peer ranks.

        On the first pipeline stage, the attention-TP leader drains the
        tokenizer socket and then the RPC socket without blocking. On later
        stages the leader instead receives the list forwarded by the previous
        stage. The result is then broadcast so that every rank returns the same
        list. With DP attention, work requests are broadcast only inside the
        attention-TP group and control requests across the whole TP group. In
        that case the returned list is work requests followed by control
        requests.
        """
        is_attn_leader = self.attn_tp_rank == 0

        if self.pp_rank != 0:
            # Later pipeline stages: requests are relayed by the previous stage.
            if is_attn_leader:
                offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + offset,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + offset,
                    self.pp_rank * self.tp_size + offset,
                )
            else:
                recv_reqs = None
        elif is_attn_leader:
            # First stage: drain each socket in turn until it would block.
            recv_reqs = []
            for sock in (self.recv_from_tokenizer, self.recv_from_rpc):
                while True:
                    try:
                        recv_reqs.append(sock.recv_pyobj(zmq.NOBLOCK))
                    except zmq.ZMQError:
                        break
        else:
            recv_reqs = None

        if not self.server_args.enable_dp_attention:
            if self.tp_size == 1:
                return recv_reqs
            return broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )

        work_reqs = None
        control_reqs = None
        if is_attn_leader:
            work_types = (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
            work_reqs = []
            control_reqs = []
            for item in recv_reqs:
                if isinstance(item, work_types):
                    work_reqs.append(item)
                else:
                    control_reqs.append(item)

        if self.attn_tp_size != 1:
            work_reqs = broadcast_pyobj(
                work_reqs,
                self.attn_tp_group.rank,
                self.attn_tp_cpu_group,
                src=self.attn_tp_group.ranks[0],
            )
        if self.tp_size != 1:
            control_reqs = broadcast_pyobj(
                control_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return work_reqs + control_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
880
    def process_input_requests(self, recv_reqs: List):
        """Dispatch every received request and forward any immediate reply.

        Health-check generate requests are counted in
        ``self.return_health_check_ct`` and skipped while a chunked request or
        a non-empty running batch exists. RPC replies go back through
        ``recv_from_rpc`` when that socket exists; all other replies go to the
        tokenizer.
        """
        for req in recv_reqs:
            if is_health_check_generate_req(req):
                busy = self.chunked_req is not None or not self.running_batch.is_empty()
                if busy:
                    # Ignore health checks while there is ongoing work.
                    self.return_health_check_ct += 1
                    continue

            reply = self._request_dispatcher(req)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
896
897
898
899
900

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Turn a tokenized generate request into a ``Req`` and enqueue it.

        The ``Req`` is either built fresh or created from an existing session
        (``recv_req.session_params.id`` present in ``self.sessions``). Validation
        failures do not drop the request. Instead it gets an abort reason or
        zeroed generation budget and is still enqueued. One exception is a
        disaggregated request without a bootstrap room, which is aborted and
        streamed out immediately. Requests with a structured-output constraint
        go to ``self.grammar_queue`` when their grammar is not yet cached.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_message = (
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
                    logger.error(error_message)
                    prepare_abort(req, error_message)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but is unknown.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.origin_input_ids = [0]
            req.sampling_params.max_new_tokens = 0
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            req.finished_reason = FINISH_ABORT(
                f"logprob_start_len, ({req.logprob_start_len}) is higher than the number of input tokens ({len(req.origin_input_ids)}). Request with a lower logprob_start_len.",
                HTTPStatus.BAD_REQUEST,
                "BadRequestError",
            )
            req.logprob_start_len = len(req.origin_input_ids) - 1
            self._add_request_to_queue(req)
            return

        # Cap generation so prompt + output stays below max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        if (
            req.sampling_params.json_schema is not None
            or req.sampling_params.regex is not None
            or req.sampling_params.ebnf is not None
            or req.sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if req.sampling_params.json_schema is not None:
                key = ("json", req.sampling_params.json_schema)
            elif req.sampling_params.regex is not None:
                key = ("regex", req.sampling_params.regex)
            elif req.sampling_params.ebnf is not None:
                key = ("ebnf", req.sampling_params.ebnf)
            else:
                # The guard above ensures structural_tag is not None here. Using
                # `else` (rather than a truthiness test) keeps `key` always bound,
                # even for a falsy-but-present tag such as "".
                key = ("structural_tag", req.sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                req.grammar_key = key
                add_to_grammar_queue = True

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Record the enqueue time on ``req`` and put it in the mode's entry queue.

        PREFILL mode uses the bootstrap queue, DECODE mode the pre-allocation
        queue, and every other mode the plain waiting queue.
        """
        req.queue_time_start = time.perf_counter()
        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(req)
            return
        if mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return
        self.waiting_queue.append(req)

    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
        """Append several requests to the entry queue in one call.

        DECODE mode uses the pre-allocation queue; all other modes use the
        waiting queue. Unlike ``_add_request_to_queue``, no ``queue_time_start``
        is stamped here. ``is_retracted`` is accepted but not consulted.
        """
        in_decode_mode = self.disaggregation_mode == DisaggregationMode.DECODE
        target = self.disagg_decode_prealloc_queue if in_decode_mode else self.waiting_queue
        target.extend(reqs)
1073
1074
1075

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Turn a tokenized embedding request into a ``Req`` and enqueue it.

        Multimodal inputs are expanded into padded token ids first. Requests
        that become too long after expansion, or that fail
        ``validate_input_length``, are still enqueued rather than dropped. Every
        path enqueues through ``_add_request_to_queue``, so the request always
        reaches the queue that matches the current disaggregation mode.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                # Use the shared helper (as the other paths here and in
                # handle_generate_request do) instead of appending to
                # waiting_queue directly, so disaggregation modes are honored.
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1124

1125
1126
1127
1128
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Emit a log line describing a newly formed prefill batch.

        The method also updates ``self.last_input_throughput`` from the
        prefill tokens accumulated since the previous call and resets that
        counter. When metrics are enabled it pushes token usage, queue length,
        cache-hit rate and average queue latency to the metrics collector.
        Pending KV events are published at the end in every case.

        Args:
            adder: Provides ``log_input_tokens`` / ``log_hit_tokens`` counts.
            can_run_list: Requests admitted into this prefill batch.
            running_bs: Number of requests in the running batch.
        """
        elapsed = time.perf_counter() - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = time.perf_counter()

        self.last_input_throughput = self.num_prefill_tokens / elapsed
        self.num_prefill_tokens = 0

        free_tokens = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        num_used = self.max_total_num_tokens - free_tokens
        new_seq_count = len(can_run_list)

        parts = [
            "Prefill batch. ",
            f"#new-seq: {new_seq_count}, ",
            f"#new-token: {adder.log_input_tokens}, ",
            f"#cached-token: {adder.log_hit_tokens}, ",
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, ",
            f"#running-req: {running_bs}, ",
        ]
        if self.disaggregation_mode != DisaggregationMode.PREFILL:
            parts.append(f"#queue-req: {len(self.waiting_queue)}")
        else:
            parts.extend(
                [
                    f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, ",
                    f"#queue-req: {len(self.waiting_queue)}, ",
                    f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} ",
                ]
            )
        logger.info("".join(parts))

        if self.enable_metrics:
            hit_rate = adder.log_hit_tokens / (
                adder.log_input_tokens + adder.log_hit_tokens
            )
            stats = self.stats
            stats.num_running_reqs = running_bs
            stats.num_used_tokens = num_used
            stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            stats.num_queue_reqs = len(self.waiting_queue)
            stats.cache_hit_rate = hit_rate

            waited = 0
            for admitted in can_run_list:
                waited += admitted.queue_time_end - admitted.queue_time_start
            stats.avg_request_queue_latency = waited / new_seq_count

            self.metrics_collector.log_stats(stats)
        self._publish_kv_events()
1177

1178
1179
1180
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a summary line for the current decode batch and update decode metrics.

        The method refreshes ``self.last_gen_throughput`` from the tokens
        generated since the previous call and resets that counter. When
        speculative decoding is active it also folds the per-interval
        accept-length counters into the cumulative ones and resets them.

        Args:
            can_run_cuda_graph: Echoed in the log line as ``cuda graph``.
            running_batch: Batch to report on; ``self.running_batch`` is used
                when this is not given.
        """
        batch = running_batch or self.running_batch

        gap_latency = time.perf_counter() - self.last_decode_stats_tic
        self.last_decode_stats_tic = time.perf_counter()
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Avoid ZeroDivisionError (which would crash the scheduler loop) when
            # no speculative forward pass was counted since the last reset.
            forward_ct = self.spec_num_total_forward_ct
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / forward_ct if forward_ct else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1237

Lianmin Zheng's avatar
Lianmin Zheng committed
1238
1239
    def check_memory(self):
        """Idle-time self-check for leaked KV-cache tokens and request slots.

        Raises:
            ValueError: if the KV pool's free + evictable token count differs
                from ``max_total_num_tokens`` (minus the tree cache's
                ``protected_size`` when hierarchical cache is enabled), or if
                ``req_to_token_pool`` has fewer free slots than its size.

        When metrics are enabled on attention-TP rank 0 and more than 30 seconds
        have elapsed since the collector's ``last_log_time``, it also logs usage
        stats. Pending KV events are published at the end.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        memory_leak = available_size != (
            self.max_total_num_tokens
            if not self.enable_hierarchical_cache
            else self.max_total_num_tokens - protected_size
        )
        if memory_leak:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the header from the details (the
                # original concatenation produced "detected!available_size=").
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            num_used = self.max_total_num_tokens - (
                self.token_to_kv_pool_allocator.available_size()
                + self.tree_cache.evictable_size()
            )
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1285

1286
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Decide which batch the next forward pass should run.

        The last extend (prefill) batch is first folded into the running batch,
        with any chunked requests filtered out. A new prefill batch is preferred
        when ``get_new_batch_prefill`` yields one. Otherwise the updated running
        (decode) batch is used, or ``None`` if it is empty. With DP attention or
        SP layernorm enabled, the choice is passed through
        ``prepare_dp_attn_batch``.
        """
        excluded = set()
        chunked = self.chunked_req
        if chunked:
            # Keep the chunked request out of the merge; it is cached as unfinished
            # and releases its req_pool_idx (it will get a new one, same rid).
            excluded.add(chunked)
            self.tree_cache.cache_unfinished_req(chunked)
            self.req_to_token_pool.free(chunked.req_pool_idx)

        prev = self.last_batch
        if prev and prev.forward_mode.is_extend():
            if prev.chunked_req is not None:
                # With pipeline parallelism a microbatch may still reference a
                # chunked_req that is outdated after its last chunk; drop it too.
                excluded.add(prev.chunked_req)

            size_before = prev.batch_size()
            prev.filter_batch(chunked_req_to_exclude=list(excluded))
            if prev.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            if not prev.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev
                else:
                    self.running_batch.merge_batch(prev)

        # Prefill takes priority; fall back to decoding the running batch.
        next_batch = self.get_new_batch_prefill()
        if next_batch is None and not self.running_batch.is_empty():
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                next_batch = self.running_batch

        args = self.server_args
        if args.enable_dp_attention or args.enable_sp_layernorm:
            next_batch, _ = self.prepare_dp_attn_batch(next_batch)

        return next_batch
1335

1336
1337
1338
1339
1340
1341
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted next to ``running_bs`` running ones.

        The budget is ``global_server_args_dict["max_micro_batch_size"]`` minus
        ``running_bs``. With pipeline parallelism (``pp_size > 1``) it is further
        capped by the free slots reported by ``req_to_token_pool.available_size()``.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1342
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build the next prefill (extend) batch from the waiting queue.

        Requests whose grammar is ready are promoted first. The method bails out
        early when the running batch is full or nothing is waiting, unless a
        chunked request must continue. ``PrefillAdder`` then selects requests in
        policy order while respecting:

        * the LoRA-per-batch limit;
        * the allocatable-request budget;
        * in disaggregated prefill mode, the free ``req_to_token_pool`` slots.

        The selected requests are turned into a prepared ``ScheduleBatch``. With
        mixed chunked prefill, the current decode requests are folded into it.

        Returns:
            The prepared extend batch, or ``None`` if no request can run now.
            ``self.running_batch.batch_is_full`` is set when scheduling stopped
            because a capacity limit was hit.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        if self.lora_paths:
            lora_set = {req.lora_path for req in self.running_batch.reqs}

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            # Stop once admitting this request would exceed the number of
            # distinct LoRA adapters allowed in one batch.
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, the prealloc queue and the transfer queue can
                # also hold req_to_token_pool slots, so cap by the pool's actual
                # free size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if not can_run_list:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once. Constructing it inside the
        # comprehension would rebuild it per waiting request (O(n * m)).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1496
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Finished requests are filtered out first; an emptied batch is returned
        right away with ``batch_is_full`` cleared. Retraction happens when
        ``check_decode_mem`` fails, or when ``TEST_RETRACT`` forces it on a
        batch larger than 10. In that case some requests are retracted back to
        the waiting queue and ``new_token_ratio`` takes the value returned by
        ``retract_decode``. Otherwise ``new_token_ratio`` decays by
        ``new_token_ratio_decay`` but never below ``min_new_token_ratio``. If
        the batch shrank, ``batch_is_full`` is cleared. The batch is then
        prepared for the next decode step and returned.
        """
        bs_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        mem_ok = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if not mem_ok or (TEST_RETRACT and batch.batch_size() > 10):
            # Decode would run out of memory: push some requests back.
            ratio_before = self.new_token_ratio
            retracted, self.new_token_ratio = batch.retract_decode(self.server_args)
            logger.info(
                "Decode out of memory happened. "
                f"#retracted_reqs: {len(retracted)}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted)
        else:
            decayed = self.new_token_ratio - self.new_token_ratio_decay
            self.new_token_ratio = max(decayed, self.min_new_token_ratio)

        if batch.batch_size() < bs_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1532

1533
1534
1535
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Before the forward pass, this method:

        * increments ``forward_ct``;
        * stops the profiler once ``profiler_target_forward_ct`` is reached;
        * sleeps ``forward_sleep_time`` seconds if that is set.

        Embedding / reward models return an ``EmbeddingBatchResult``. Generation
        models run through the TP worker, or through the draft worker when a
        speculative algorithm is configured, and return a
        ``GenerationBatchResult``. On the last pipeline rank that result carries
        logits and next-token ids; on other ranks it carries the hidden-state
        proxy tensors instead.
        """
        self.forward_ct += 1

        # Check profiler
        target_ct = self.profiler_target_forward_ct
        if target_ct and target_ct <= self.forward_ct:
            self.send_to_tokenizer.send_pyobj(self.stop_profile())

        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        if not self.is_generation:
            # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            return EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )

        is_last_pp_rank = self.pp_group.is_last_rank

        # Run forward
        if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()
            forward_out = self.tp_worker.forward_batch_generation(model_worker_batch)
            if is_last_pp_rank:
                logits_output, next_token_ids, can_run_cuda_graph = forward_out
            else:
                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = forward_out
            bid = model_worker_batch.bid
        else:
            spec_out = self.draft_worker.forward_batch_speculative_generation(batch)
            (
                logits_output,
                next_token_ids,
                bid,
                num_accepted_tokens,
                can_run_cuda_graph,
            ) = spec_out
            bs = batch.batch_size()
            self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
            self.spec_num_total_forward_ct += bs
            self.num_generated_tokens += num_accepted_tokens

        if is_last_pp_rank:
            batch.output_ids = next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
        # we can use the correct values in output processing.
        extend_input_len_per_req = None
        extend_logprob_start_len_per_req = None
        if batch.return_logprob:
            extend_input_len_per_req = [r.extend_input_len for r in batch.reqs]
            extend_logprob_start_len_per_req = [
                r.extend_logprob_start_len for r in batch.reqs
            ]

        if is_last_pp_rank:
            return GenerationBatchResult(
                logits_output=logits_output,
                pp_hidden_states_proxy_tensors=None,
                next_token_ids=next_token_ids,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        return GenerationBatchResult(
            logits_output=None,
            pp_hidden_states_proxy_tensors=pp_hidden_states_proxy_tensors,
            next_token_ids=None,
            extend_input_len_per_req=extend_input_len_per_req,
            extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
            bid=bid,
            can_run_cuda_graph=can_run_cuda_graph,
        )
Chayenne's avatar
Chayenne committed
1612

1613
1614
1615
1616
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Route ``result`` to the handler matching ``batch.forward_mode``.

        * Decode and extend batches go to their dedicated handlers.
        * Idle batches: only under overlap scheduling, resolve the previous
          batch result and mark the next batch's sampling info as done.
        * Dummy-first batches: only mark the sampling info as done.

        Afterwards, if health-check replies are pending, one
        ``HealthCheckOutput`` is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Return some signal for the health check.
        # This is used to prevent the health check signal being blocked by long context prefill.
        # However, one minor issue is that this code path does not check the status of detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1637
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Sync ``local_batch`` metadata with the other DP-attention workers.

        This is a thin wrapper that passes this scheduler's configuration to
        ``prepare_dp_attn_batch_raw`` and returns that method's
        ``(batch, any_extend)`` tuple unchanged.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
    ):
        """Exchange per-rank batch metadata for DP attention and annotate the batch.

        Every rank contributes four int64 values:

        * its token count;
        * whether it can use a CUDA graph (1/0);
        * the number of tokens needing logprob processing;
        * whether its batch is an extend batch.

        These are all-gathered over ``tp_cpu_group`` into a
        ``(dp_size, attn_tp_size, 4)`` tensor, of which only the entry for
        attention-TP rank 0 of each DP group is used. If this rank has no batch
        but some DP group has tokens, an idle batch is created via
        ``get_idle_batch`` so this rank participates as well.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` may still be ``None``.
            ``any_extend`` is True if any gathered entry reported an extend
            batch.
        """
        # Describe the local batch.
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                # Under EAGLE each decode request counts for
                # speculative_num_draft_tokens tokens.
                num_tokens *= speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                # We should have at least 1 token for sample in every case.
                max(extend_len - start_len, 1)
                for start_len, extend_len in zip(
                    local_batch.extend_logprob_start_lens, local_batch.extend_lens
                )
            )

        # Local CUDA-graph eligibility (1 = eligible, 0 = not).
        can_cuda_graph = (
            1
            if local_batch is None or local_batch.forward_mode.is_decode_or_idle()
            else 0
        )
        # TODO(sang): Support cuda graph when idle batch is there.
        if not spec_algorithm.is_none() and (
            local_batch is None or local_batch.forward_mode.is_idle()
        ):
            can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        # Gather [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend]
        # from every rank in tp_cpu_group.
        local_info = torch.tensor(
            [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch],
            dtype=torch.int64,
        )
        global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(), local_info, group=tp_cpu_group
        )

        # Keep only the entry of attention-TP rank 0 in each DP group.
        per_dp_info = global_info[:, 0]
        global_num_tokens = per_dp_info[:, 0].tolist()
        can_cuda_graph = min(per_dp_info[:, 1].tolist())
        global_num_tokens_for_logprob = per_dp_info[:, 2].tolist()
        is_extend_in_batch = per_dp_info[:, 3].tolist()

        # Participate with an idle batch when other DP groups have work.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            use_local_counts = (
                moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]
            )
            if use_local_counts:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1738
1739
1740
1741
1742

    def get_idle_batch(self):
        """Create a request-less ``ScheduleBatch`` prepared for idle mode.

        ``prepare_dp_attn_batch`` passes this as the factory used when this rank
        has no batch but other DP groups have tokens to process.
        """
        empty_reqs = []
        batch = ScheduleBatch.init_new(
            empty_reqs,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        batch.prepare_for_idle()
        return batch

1753
1754
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
1755

1756
        num_ready_reqs = 0
1757
        num_abort_reqs = 0
1758
1759
        for req in self.grammar_queue:
            try:
1760
1761
1762
                req.grammar = req.grammar.result(timeout=0.03)
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
1763
1764
                num_ready_reqs += 1
            except futures._base.TimeoutError:
1765
1766
1767
                req.grammar_wait_ct += 1
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / 0.03:
                    num_abort_reqs = 1
1768
1769
                break

1770
        if self.server_args.enable_dp_attention:
1771
1772
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
1773
        else:
1774
1775
1776
1777
1778
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
1779
            tensor = torch.tensor([num_ready_reqs, num_abort_reqs], dtype=torch.int32)
1780
1781
1782
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
1783
1784
            num_ready_reqs_max, num_abort_reqs_max = tensor.tolist()

1785
            for i in range(num_ready_reqs, num_ready_reqs_max):
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
                req = self.grammar_queue[i]
                req.grammar = req.grammar.result()
                if req.grammar:
                    self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())

            for i in range(num_ready_reqs, num_ready_reqs + num_abort_reqs_max):
                req = self.grammar_queue[i]
                req.grammar.cancel()
                req.grammar = None
                error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
                logger.error(error_msg)
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
            num_ready_reqs = num_ready_reqs_max + num_abort_reqs_max
1801

1802
        self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
1803
1804
        self.grammar_queue = self.grammar_queue[num_ready_reqs:]

1805
1806
1807
1808
1809
1810
1811
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Signal that the next batch's sampling info is ready.

        If the batch carries grammars, the regex vocab mask is updated and the
        current stream is synchronized before the ``sampling_info_done`` event
        is set. Does nothing when there is no next-batch sampling info.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

1812
1813
1814
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Every ``watchdog_timeout / 2`` seconds it checks whether ``forward_ct``
        has advanced while a batch is in flight (``cur_batch`` is not None).
        Suppose it has not advanced for more than ``watchdog_timeout`` seconds.
        Then the thread:

        * logs the batch and memory-pool state, unless request logging is
          disabled;
        * dumps the scheduler stacks;
        * sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division yields 0 for timeouts below 2s and
            # would turn this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        # Snapshot cur_batch: the scheduler thread may reset it concurrently, and
        # an AttributeError here would kill the watchdog before it can act.
        cur_batch = self.cur_batch
        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            if cur_batch is not None:
                batch_info = (
                    f"self.cur_batch.batch_size()={cur_batch.batch_size()!r}, "
                    f"self.cur_batch.reqs={cur_batch.reqs!r}, "
                )
            else:
                batch_info = "self.cur_batch=None, "
            logger.error(
                batch_info
                + f"{self.token_to_kv_pool_allocator.available_size()=}, "
                + f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1846
1847
1848
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush request by calling ``flush_cache`` and reporting its outcome."""
        return FlushCacheReqOutput(success=self.flush_cache())
1849

1850
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush is only performed when nothing is queued or running. With
        ``pp_size != 1``, every micro-batch in ``running_mbs`` must be empty as
        well. On success, all of the following are reset and True is returned:

        * the tree cache;
        * the grammar backend, if present;
        * the request/token pools, including the draft worker's when
          speculative decoding is enabled;
        * the throughput and speculative counters.

        Otherwise a warning is logged and False is returned.
        """
        busy = (
            len(self.waiting_queue) > 0
            or not self.running_batch.is_empty()
            or (
                self.pp_size != 1
                and not all(mb.is_empty() for mb in self.running_mbs)
            )
        )
        if busy:
            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            draft_runner = self.draft_worker.model_runner
            draft_runner.req_to_token_pool.clear()
            draft_runner.token_to_kv_pool_allocator.clear()

        # Reset throughput and speculative-decoding statistics.
        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0

        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a snapshot of the global server args plus runtime statistics.

        The snapshot always includes ``last_gen_throughput``. It additionally
        contains:

        * ``avg_spec_accept_length``, when speculative decoding is enabled and
          at least one accept-length sample exists;
        * ``step_time_dict``, when ``RECORD_STEP_TIME`` is set.
        """
        snapshot = {
            **global_server_args_dict,
            "last_gen_throughput": self.last_gen_throughput,
        }
        spec_enabled = not self.spec_algorithm.is_none()
        if spec_enabled and self.cum_spec_accept_count > 0:
            snapshot["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            snapshot["step_time_dict"] = self.step_time_dict
        return GetInternalStateReqOutput(internal_state=snapshot)

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of ``global_server_args_dict`` at runtime.

        Only these keys may change:

        * ``max_micro_batch_size``, which must lie in
          ``[1, max_running_requests // pp_size]``;
        * ``speculative_accept_threshold_single``;
        * ``speculative_accept_threshold_acc``.

        The update is all-or-nothing. On the first unsupported key or
        out-of-range value, a warning is logged and nothing is applied. On
        success, the cumulative speculative accept-length statistics are logged
        (when available) and reset.

        Returns:
            A ``SetInternalStateReqOutput``. Its ``updated`` field is True only
            if the update was applied. It also carries the current
            ``global_server_args_dict``.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logging.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logging.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break
        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! " f"{global_server_args_dict=}")
        return SetInternalStateReqOutput(
            # Report whether the update was actually applied; a rejected update
            # must not look successful to the caller.
            updated=if_success,
            server_args=global_server_args_dict,
        )

1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Invoke ``self.<recv_req.method>(recv_req.parameters)`` and report the outcome.

        Any exception raised by the target method is caught and logged. A
        ``barrier()`` always follows the call. The returned ``RpcReqOutput``
        carries the success flag and the error text ("" on success).
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        error = None
        try:
            target = getattr(self, recv_req.method)
            target(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        message = str(error) if error else ""
        return RpcReqOutput(error is None, message)

    def save_remote_model(self, params):
        """Save the served model to the remote location given by ``params["url"]``.

        The actual work is delegated to the model runner of the TP worker.
        """
        target_url = params["url"]
        model_runner = self.tp_worker.worker.model_runner
        model_runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save the model in shards via the TP worker's model runner.

        ``params`` must provide the ``"path"``, ``"pattern"`` and ``"max_size"``
        entries, which are forwarded as keyword arguments of the same names.
        """
        model_runner = self.tp_worker.worker.model_runner
        save_fn = model_runner.save_sharded_model
        save_fn(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )

1973
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching requests still in the waiting queue are removed from it, and an
        ``AbortReq`` carrying their rid is sent to the tokenizer for each one.
        Matching requests that are not finished in the running batch (and in
        ``cur_batch`` when it is a different batch) are only flagged with
        ``to_abort = True`` here.
        """
        # TODO(lmzheng): abort the requests in the grammar queue.

        # Delete requests in the waiting queue: gather matching positions first,
        # then pop from the highest index down so the remaining positions stay valid.
        matching_positions = [
            position
            for position, queued in enumerate(self.waiting_queue)
            if queued.rid.startswith(recv_req.rid)
        ]
        for position in reversed(matching_positions):
            req = self.waiting_queue.pop(position)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Flag requests in the running batch (plus cur_batch when it differs).
        current = self.cur_batch
        if current is None or current is self.running_batch:
            candidates = self.running_batch.reqs
        else:
            candidates = self.running_batch.reqs + current.reqs

        for req in candidates:
            if not req.rid.startswith(recv_req.rid) or req.finished():
                continue
            logger.debug(f"Abort running request. {req.rid=}")
            req.to_abort = True
1998

1999
2000
2001
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine; this scheduler provides no implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2002
2003
2004
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        Delegates the load to the TP worker. On failure the worker's message is
        logged as an error. On success the cache is flushed, and an assertion
        fails if the flush reports failure. The trailing ``0`` is passed through
        to ``UpdateWeightFromDiskReqOutput`` unchanged; its meaning is defined by
        that class (not visible here).
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if not success:
            logger.error(message)
        else:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2011

2012
2013
2014
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result
        in an ``InitWeightsUpdateGroupReqOutput``.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
2016
2017

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        Delegates the update to the TP worker. On success the cache is flushed,
        and an assertion fails if the flush reports failure. On failure the
        worker's message is logged as an error.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput(success, message)``. The
            previous annotation ``Tuple[bool, str]`` did not match what is
            actually returned.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2029

2030
2031
2032
2033
2034
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        Delegates to the TP worker. On failure the worker's message is logged as
        an error. On success the cache is flushed only when
        ``recv_req.flush_cache`` is truthy, and an assertion fails if that flush
        reports failure.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if not success:
            logger.error(message)
        elif recv_req.flush_cache:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        return UpdateWeightsFromTensorReqOutput(success, message)
2041

2042
2043
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch a parameter from the TP worker and wrap it in the output object."""
        fetched_parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(fetched_parameter)
2045

2046
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Pause the memory saver adapter, keeping a copy of the model's buffers.

        1. Validate the call through ``memory_saver_adapter.check_validity``.
        2. Stash a snapshot of the model's buffers (``_export_static_state``) on
           ``self.stashed_model_static_state`` so that
           ``resume_memory_occupation`` can restore them.
        3. Pause the adapter and flush the cache. The flush result is not
           checked here.
        """
        saver = self.memory_saver_adapter
        saver.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)

        saver.pause()
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2056

2057
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        This undoes ``release_memory_occupation``. The buffers saved in
        ``self.stashed_model_static_state`` are copied back into the model, and
        the stash is then deleted. It therefore expects a preceding release call.
        """
        saver = self.memory_saver_adapter
        saver.check_validity(caller_name="resume_memory_occupation")
        saver.resume()

        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

2066
2067
2068
2069
2070
2071
2072
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store the requested forward sleep time on ``self.forward_sleep_time``.

        A non-positive value is normalized to ``None``. The stored value is
        presumably consumed by the forward loop (not visible here).
        """
        requested = recv_req.forward_sleep_time
        self.forward_sleep_time = (
            None if requested is not None and requested <= 0 else requested
        )
        return SlowDownReqOutput()

2073
    def profile(self, recv_req: ProfileReq):
        """Dispatch a profiling request.

        ``ProfileReqType.START_PROFILE`` forwards the request's options to
        ``start_profile``. Any other type calls ``stop_profile``. The result of
        whichever is called is returned.
        """
        if recv_req.type == ProfileReqType.START_PROFILE:
            return self.start_profile(
                output_dir=recv_req.output_dir,
                num_steps=recv_req.num_steps,
                activities=recv_req.activities,
                with_stack=recv_req.with_stack,
                record_shapes=recv_req.record_shapes,
                profile_id=recv_req.profile_id,
            )
        return self.stop_profile()

    def start_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_id: Optional[str],
    ) -> Optional[ProfileReqOutput]:
        """Start a profiling session for this scheduler process.

        Args:
            output_dir: Directory for the traces. Defaults to
                ``$SGLANG_TORCH_PROFILER_DIR`` or ``/tmp``.
            num_steps: If truthy, profiling targets ``forward_ct + num_steps``.
                No reply is returned here; the caller is notified when that
                forward count is reached.
            activities: Any of ``"CPU"``, ``"GPU"`` (torch profiler), ``"MEM"``
                (CUDA memory history) and ``"CUDA_PROFILER"``. Defaults to
                ``["CPU", "GPU"]``.
            with_stack: Passed to ``torch.profiler.profile``. Defaults to True.
            record_shapes: Passed to ``torch.profiler.profile``. Defaults to False.
            profile_id: Prefix for the output file names. When None, the
                current timestamp is used, because ``stop_profile`` concatenates
                this value into file names and would fail on None.

        Returns:
            A failed ``ProfileReqOutput`` if profiling is already in progress,
            a successful one when ``num_steps`` is not set, otherwise ``None``.
        """
        if self.profiler_activities:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        if output_dir is None:
            output_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
        if activities is None:
            activities = ["CPU", "GPU"]
        if profile_id is None:
            # stop_profile() builds file names as ``profiler_id + "..."``.
            profile_id = str(time.time())

        self.torch_profiler_output_dir = output_dir
        self.profiler_activities = activities
        self.profiler_id = profile_id
        logger.info(
            "Profiling starts. Traces will be saved to: %s (with id %s)",
            self.torch_profiler_output_dir,
            self.profiler_id,
        )

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        # Only CPU/GPU go through torch.profiler; MEM and CUDA_PROFILER are
        # handled separately below.
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()

        if num_steps:
            self.profiler_target_forward_ct = self.forward_ct + num_steps
            # The caller will be notified when reaching profiler_target_forward_ct
        else:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")
2143
2144

    def stop_profile(self) -> ProfileReqOutput:
        """Stop the profiling session started by ``start_profile`` and save results.

        Depending on ``self.profiler_activities``, this method:
        - stops the torch profiler and exports a Chrome trace named
          ``<id>-TP-<rank>.trace.json.gz``;
        - dumps the CUDA memory snapshot to ``<id>-TP-<rank>-memory.pickle``
          and disables memory history recording;
        - stops the CUDA profiler.

        Afterwards ``torch_profiler``, ``torch_profiler_output_dir`` and
        ``profiler_activities`` are reset.

        Returns:
            A failed ``ProfileReqOutput`` if no profiling is in progress,
            otherwise a successful one.
        """
        if self.profiler_activities is None:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        logger.info("Stop profiling...")

        # A custom output_dir may not exist yet; the exporters below open files
        # inside it and would fail otherwise.
        if self.torch_profiler_output_dir:
            os.makedirs(self.torch_profiler_output_dir, exist_ok=True)
        # profiler_id is the file-name prefix; avoid ``None + str`` TypeError.
        file_prefix = (
            self.profiler_id if self.profiler_id is not None else str(time.time())
        )

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    file_prefix + f"-TP-{self.tp_rank}" + ".trace.json.gz",
                )
            )

        if "MEM" in self.profiler_activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                file_prefix + f"-TP-{self.tp_rank}-memory" + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in self.profiler_activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.torch_profiler_output_dir = None
        self.profiler_activities = None

        return ProfileReqOutput(success=True, message="Succeeded")
2181

2182
2183
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Forward an expert-distribution command to the global recorder.

        ``START_RECORD``, ``STOP_RECORD`` and ``DUMP_RECORD`` call
        ``start_record()``, ``stop_record()`` and ``dump_record()`` respectively
        on the object returned by ``get_global_expert_distribution_recorder()``.
        The recorder's return value is discarded.

        Raises:
            ValueError: if ``recv_req`` is none of the three commands.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            recorder_method = "start_record"
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            recorder_method = "stop_record"
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            recorder_method = "dump_record"
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        # The recorder is looked up only after validation, so an unknown
        # command never touches it.
        getattr(get_global_expert_distribution_recorder(), recorder_method)()
        return ExpertDistributionReqOutput()
2192

2193
    def open_session(self, recv_req: OpenSessionReqInput):
        """Register a new ``Session`` under ``recv_req.session_id``.

        Returns ``OpenSessionReqOutput(session_id, False)`` and logs a warning
        when the id is already registered or is None. Otherwise it creates the
        session with ``recv_req.capacity_of_str_len`` and returns
        ``OpenSessionReqOutput(session_id, True)``.
        """
        # handle error
        sid = recv_req.session_id
        if sid in self.sessions:
            logger.warning(f"session id {sid} already exist, cannot open.")
            return OpenSessionReqOutput(sid, False)
        if sid is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(sid, False)

        self.sessions[sid] = Session(recv_req.capacity_of_str_len, sid)
        return OpenSessionReqOutput(sid, True)
2207
2208
2209
2210
2211
2212
2213
2214
2215

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session ``recv_req.session_id``; log a warning if it is unknown."""
        sid = recv_req.session_id
        if sid in self.sessions:
            del self.sessions[sid]
        else:
            logger.warning(f"session id {sid} does not exist, cannot delete.")

2216
2217
    def get_print_prefix(self):
        """Return the log prefix identifying this scheduler, e.g. ``" DP0 TP1 PP2"``.

        Each tag is added only when relevant: DP when ``attn_dp_rank`` is set,
        TP when ``server_args.tp_size > 1``, and PP when ``pp_size > 1``.
        Returns ``""`` when none apply.
        """
        tags = []
        if self.attn_dp_rank is not None:
            tags.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            tags.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            tags.append(f" PP{self.pp_rank}")
        return "".join(tags)

2226
2227
2228
2229
2230
2231
2232
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2233

2234
2235
2236
2237
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health-check requests are recognized by an ``rid`` that starts with
    ``"HEALTH_CHECK"``. Objects without an ``rid`` attribute return False.
    So do objects whose ``rid`` is not a string (e.g. ``None``), which
    previously raised AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2252
2253
2254
2255
2256
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Setup steps, in order:
    - parent-death handling, process title and fault handler;
    - logging and optional CPU affinity;
    - the embedding cache.

    It then builds a ``Scheduler``, sends a ready message with the scheduler's
    token limits through ``pipe_writer``, and runs the event loop chosen by the
    disaggregation mode, the pipeline-parallel size and the overlap setting.
    If anything in that phase raises, the traceback is logged and SIGQUIT is
    sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolved before the prefix is generated so that the log prefix and the
    # process title reflect the DP rank the scheduler actually runs with.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    # Embedding cache size in MB (default 100, override via SGLANG_VLM_CACHE_SIZE_MB);
    # init_embedding_cache receives the size in bytes.
    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)