scheduler.py 101 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""A scheduler that manages a tensor parallel GPU worker."""

16
import faulthandler
17
import logging
18
import os
19
import signal
20
import sys
Lianmin Zheng's avatar
Lianmin Zheng committed
21
import threading
22
import time
23
from collections import defaultdict, deque
Lianmin Zheng's avatar
Lianmin Zheng committed
24
from concurrent import futures
25
from dataclasses import dataclass
26
from http import HTTPStatus
27
from pathlib import Path
28
from types import SimpleNamespace
29
from typing import Dict, List, Optional, Tuple, Union
30

31
import psutil
32
import setproctitle
33
import torch
34
import zmq
35
from torch.distributed import barrier
36

37
from sglang.global_config import global_config
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.configs.model_config import ModelConfig
39
40
41
42
from sglang.srt.constrained.base_grammar_backend import (
    INVALID_GRAMMAR_OBJ,
    create_grammar_backend,
)
Byron Hsu's avatar
Byron Hsu committed
43
44
45
46
47
from sglang.srt.disaggregation.decode import (
    DecodePreallocQueue,
    DecodeTransferQueue,
    SchedulerDisaggregationDecodeMixin,
)
48
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
Byron Hsu's avatar
Byron Hsu committed
49
50
51
52
53
54
from sglang.srt.disaggregation.prefill import (
    PrefillBootstrapQueue,
    SchedulerDisaggregationPrefillMixin,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
55
    MetadataBuffers,
Byron Hsu's avatar
Byron Hsu committed
56
    ReqToMetadataIdxAllocator,
57
    TransferBackend,
58
    prepare_abort,
Byron Hsu's avatar
Byron Hsu committed
59
)
60
from sglang.srt.distributed import get_pp_group, get_world_group
xm:D's avatar
xm:D committed
61
62
63
64
65
from sglang.srt.hf_transformers_utils import (
    get_processor,
    get_tokenizer,
    get_tokenizer_from_processor,
)
66
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
67
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
68
69
70
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
71
72
from sglang.srt.managers.io_struct import (
    AbortReq,
73
    CloseSessionReqInput,
74
    ExpertDistributionReq,
75
    ExpertDistributionReqOutput,
76
77
    FlushCacheReqInput,
    FlushCacheReqOutput,
78
79
    GetInternalStateReq,
    GetInternalStateReqOutput,
80
81
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
82
    HealthCheckOutput,
83
84
    InitWeightsUpdateGroupReqInput,
    InitWeightsUpdateGroupReqOutput,
85
86
    OpenSessionReqInput,
    OpenSessionReqOutput,
87
    ProfileReq,
88
89
    ProfileReqOutput,
    ProfileReqType,
90
91
92
93
    ReleaseMemoryOccupationReqInput,
    ReleaseMemoryOccupationReqOutput,
    ResumeMemoryOccupationReqInput,
    ResumeMemoryOccupationReqOutput,
94
95
    RpcReqInput,
    RpcReqOutput,
96
97
    SetInternalStateReq,
    SetInternalStateReqOutput,
98
99
    SlowDownReqInput,
    SlowDownReqOutput,
100
101
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
Chayenne's avatar
Chayenne committed
102
103
    UpdateWeightFromDiskReqInput,
    UpdateWeightFromDiskReqOutput,
104
105
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
106
107
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
108
)
109
from sglang.srt.managers.mm_utils import init_embedding_cache
110
111
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
Mick's avatar
Mick committed
112
    MultimodalInputs,
113
114
    Req,
    ScheduleBatch,
115
    global_server_args_dict,
116
)
117
118
119
120
121
from sglang.srt.managers.schedule_policy import (
    AddReqResult,
    PrefillAdder,
    SchedulePolicy,
)
122
123
124
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)
125
from sglang.srt.managers.session_controller import Session
126
from sglang.srt.managers.tp_worker import TpModelWorker
127
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
128
from sglang.srt.managers.utils import validate_input_length
129
from sglang.srt.mem_cache.chunk_cache import ChunkCache
130
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
131
from sglang.srt.mem_cache.radix_cache import RadixCache
132
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
Lianmin Zheng's avatar
Lianmin Zheng committed
133
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
134
from sglang.srt.reasoning_parser import ReasoningParser
135
from sglang.srt.server_args import PortArgs, ServerArgs
136
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
137
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
138
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
139
from sglang.srt.utils import (
140
    DeepEPMode,
141
    DynamicGradMode,
142
143
    broadcast_pyobj,
    configure_logger,
Lianmin Zheng's avatar
Lianmin Zheng committed
144
    disable_request_logging,
145
    get_available_gpu_memory,
146
    get_bool_env_var,
147
    get_zmq_socket,
Lianmin Zheng's avatar
Lianmin Zheng committed
148
    kill_itself_when_parent_died,
149
    point_to_point_pyobj,
150
    pyspy_dump_schedulers,
151
    set_gpu_proc_affinity,
152
153
154
    set_random_seed,
    suppress_other_loggers,
)
155
from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)

# Test retract decode for debugging purposes
TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
# Boolean env flag SGLANG_RECORD_STEP_TIME (presumably gates step-time recording)
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
# Grammar timeout read from SGLANG_GRAMMAR_TIMEOUT, default 300
# (unit presumably seconds — confirm against the grammar backend)
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
@dataclass
class GenerationBatchResult:
    """Result container for one generation (non-embedding) batch.

    Field meanings below beyond their types are inferred from names and
    should be confirmed against the producers in the model workers.
    """

    # Logits processor output; may be None (presumably on non-final pp stages)
    logits_output: Optional[LogitsProcessorOutput]
    # Presumably the tensors handed to the next pipeline stage — confirm
    pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
    next_token_ids: Optional[List[int]]
    extend_input_len_per_req: List[int]
    extend_logprob_start_len_per_req: List[int]
    # Batch id
    bid: int
    can_run_cuda_graph: bool


@dataclass
class EmbeddingBatchResult:
    """Result container for one embedding batch: the embeddings and batch id."""

    embeddings: torch.Tensor
    bid: int


class IdleSleeper:
    """
    In setups which have long inactivity periods it is desirable to reduce
    system power consumption when sglang does nothing. This leads not only
    to power savings, but also to more CPU thermal headroom when a request
    eventually comes. This is important when multiple GPUs are connected,
    as each GPU would otherwise pin one thread at 100% CPU usage.

    The simplest solution is to use zmq.Poller on all sockets that may receive
    data that needs handling immediately.

    Args:
        sockets: ZMQ sockets to watch for incoming messages.
        poll_timeout_ms: Upper bound, in milliseconds, on how long
            ``maybe_sleep`` blocks when no watched socket becomes readable.
            Defaults to 1000.
    """

    def __init__(self, sockets, poll_timeout_ms: int = 1000):
        self.poll_timeout_ms = poll_timeout_ms
        self.poller = zmq.Poller()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        """Block until a watched socket is readable or the timeout elapses."""
        # The returned event list is not used: this call only exists to wait
        # without spinning the CPU.
        self.poller.poll(self.poll_timeout_ms)


class Scheduler(
    SchedulerOutputProcessorMixin,
    SchedulerDisaggregationDecodeMixin,
    SchedulerDisaggregationPrefillMixin,
):
208
209
210
211
212
213
214
215
    """A scheduler that manages a tensor parallel GPU worker."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
        gpu_id: int,
        tp_rank: int,
        pp_rank: int,
        dp_rank: Optional[int],
    ):
        """Set up IPC, model workers, memory/cache, policies and bookkeeping.

        Args:
            server_args: Server configuration.
            port_args: IPC endpoint names and the NCCL port.
            gpu_id: Device index handed to the TP (and draft) worker.
            tp_rank: Tensor-parallel rank of this scheduler.
            pp_rank: Pipeline-parallel rank of this scheduler.
            dp_rank: Data-parallel rank, or None.
        """
        # Parse args
        self.server_args = server_args
        self.tp_rank = tp_rank
        self.pp_rank = pp_rank
        self.tp_size = server_args.tp_size
        self.pp_size = server_args.pp_size
        self.dp_size = server_args.dp_size
        self.schedule_policy = server_args.schedule_policy
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
        self.skip_tokenizer_init = server_args.skip_tokenizer_init
        self.enable_metrics = server_args.enable_metrics
        self.enable_kv_cache_events = server_args.kv_events_config is not None
        self.stream_interval = server_args.stream_interval
        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.gpu_id = gpu_id
        self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
        self.page_size = server_args.page_size
        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
            compute_dp_attention_world_info(
                server_args.enable_dp_attention,
                self.tp_rank,
                self.tp_size,
                self.dp_size,
            )
        )

        # Init inter-process communication
        context = zmq.Context(2)
        self.idle_sleeper = None

        if self.pp_rank == 0 and self.attn_tp_rank == 0:
            # Only attention-TP rank 0 of the first pipeline stage owns the
            # ZMQ endpoints to the tokenizer / detokenizer / RPC channel.
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name, False
            )
            self.send_to_tokenizer = get_zmq_socket(
                context, zmq.PUSH, port_args.tokenizer_ipc_name, False
            )

            if server_args.skip_tokenizer_init:
                # Directly send to the TokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.tokenizer_ipc_name, False
                )
            else:
                # Send to the DetokenizerManager
                self.send_to_detokenizer = get_zmq_socket(
                    context, zmq.PUSH, port_args.detokenizer_ipc_name, False
                )

            self.recv_from_rpc = get_zmq_socket(
                context, zmq.DEALER, port_args.rpc_ipc_name, False
            )
            if self.server_args.sleep_on_idle:
                self.idle_sleeper = IdleSleeper(
                    [
                        self.recv_from_tokenizer,
                        self.recv_from_rpc,
                    ]
                )
        else:
            # Other ranks have no receive sockets, and their sends are no-ops.
            self.recv_from_tokenizer = None
            self.recv_from_rpc = None
            self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
            self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

        # Init tokenizer
        self.init_tokenizer()

        # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
        if self.server_args.reasoning_parser and self.tokenizer:
            reasoning_parser = ReasoningParser(
                model_type=self.server_args.reasoning_parser, stream_reasoning=False
            )
            self.tokenizer.think_end_id = self.tokenizer.encode(
                reasoning_parser.detector.think_end_token, add_special_tokens=False
            )[0]

        # Check whether overlap can be enabled
        if not self.is_generation:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for embedding models.")

        # Launch a tensor parallel worker
        if self.enable_overlap:
            TpWorkerClass = TpModelWorkerClient
        else:
            TpWorkerClass = TpModelWorker

        self.tp_worker = TpWorkerClass(
            server_args=server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            pp_rank=pp_rank,
            dp_rank=dp_rank,
            nccl_port=port_args.nccl_port,
        )

        # Launch a draft worker for speculative decoding
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

        # Get token and memory info from the model worker
        (
            self.max_total_num_tokens,
            self.max_prefill_tokens,
            self.max_running_requests,
            self.max_req_len,
            self.max_req_input_len,
            self.random_seed,
            self.device,
            worker_global_server_args_dict,
            _,
            _,
            _,
        ) = self.tp_worker.get_worker_info()
        if global_server_args_dict["max_micro_batch_size"] is None:
            # Default: split the running-request budget evenly over pp stages.
            global_server_args_dict["max_micro_batch_size"] = max(
                self.max_running_requests // server_args.pp_size, 1
            )

        self.tp_group = self.tp_worker.get_tp_group()
        self.tp_cpu_group = self.tp_group.cpu_group
        self.attn_tp_group = self.tp_worker.get_attention_tp_group()
        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pp_group = get_pp_group()
        self.world_group = get_world_group()

        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)

        # Print debug info
        if tp_rank == 0:
            avail_mem = get_available_gpu_memory(
                self.device, self.gpu_id, empty_cache=False
            )
            logger.info(
                f"max_total_num_tokens={self.max_total_num_tokens}, "
                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
                f"max_prefill_tokens={self.max_prefill_tokens}, "
                f"max_running_requests={self.max_running_requests}, "
                f"context_len={self.model_config.context_len}, "
                f"available_gpu_mem={avail_mem:.2f} GB"
            )

        # Init memory pool and cache
        self.init_memory_pool_and_cache()

        # Init running status
        self.waiting_queue: List[Req] = []
        # The running decoding batch for continuous batching
        self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
        # The current forward batch
        self.cur_batch: Optional[ScheduleBatch] = None
        # The last forward batch
        self.last_batch: Optional[ScheduleBatch] = None
        self.forward_ct = 0
        self.forward_ct_decode = 0
        self.num_generated_tokens = 0
        self.num_prefill_tokens = 0
        self.last_decode_stats_tic = time.perf_counter()
        self.last_prefill_stats_tic = time.perf_counter()
        self.return_health_check_ct = 0
        self.current_stream = torch.get_device_module(self.device).current_stream()
        if self.device == "cpu":
            self.current_stream.synchronize = lambda: None  # No-op for CPU
        self.forward_sleep_time = None

        # Init session info
        self.sessions: Dict[str, Session] = {}

        # Init chunked prefill. A non-positive size (e.g. -1) disables it;
        # a missing (None) size is treated the same way instead of crashing
        # on `None <= 0`.
        self.chunked_prefill_size = server_args.chunked_prefill_size
        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
            self.chunked_prefill_size = None
        self.chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )

        # Init the grammar backend for constrained generation
        self.grammar_queue: List[Req] = []
        if not server_args.skip_tokenizer_init:
            self.grammar_backend = create_grammar_backend(
                server_args, self.tokenizer, self.model_config.vocab_size
            )
        else:
            self.grammar_backend = None

        # Init schedule policy and new token estimation
        self.policy = SchedulePolicy(
            self.schedule_policy,
            self.tree_cache,
            self.enable_hierarchical_cache,
        )
        assert (
            server_args.schedule_conservativeness >= 0
        ), "Invalid schedule_conservativeness"
        # Both ratios are capped at 1.0; the decay splits (init - min) evenly
        # over default_new_token_ratio_decay_steps.
        self.init_new_token_ratio = min(
            global_config.default_init_new_token_ratio
            * server_args.schedule_conservativeness,
            1.0,
        )
        self.min_new_token_ratio = min(
            self.init_new_token_ratio
            * global_config.default_min_new_token_ratio_factor,
            1.0,
        )
        self.new_token_ratio_decay = (
            self.init_new_token_ratio - self.min_new_token_ratio
        ) / global_config.default_new_token_ratio_decay_steps
        self.new_token_ratio = self.init_new_token_ratio

        # Init watchdog thread
        self.watchdog_timeout = server_args.watchdog_timeout
        t = threading.Thread(target=self.watchdog_thread, daemon=True)
        t.start()
        self.parent_process = psutil.Process().parent()

        # Init memory saver
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=server_args.enable_memory_saver
        )

        # Init profiler
        self.torch_profiler = None
        self.torch_profiler_output_dir: Optional[str] = None
        self.profiler_activities: Optional[List[str]] = None
        self.profile_id: Optional[str] = None
        self.profiler_target_forward_ct: Optional[int] = None
        self.profiler_target_prefill_ct: Optional[int] = None
        self.profiler_target_decode_ct: Optional[int] = None
        self.profiler_prefill_ct: Optional[int] = None
        self.profiler_decode_ct: Optional[int] = None
        self.profile_by_stage: bool = False
        self.profile_steps: Optional[int] = None
        self.profile_in_progress: bool = False
        self.rpd_profiler = None

        # Init metrics stats
        self.init_metrics()
        self.init_kv_events(server_args.kv_events_config)

        # Init request dispatcher: maps each incoming request type to its handler
        self._request_dispatcher = TypeBasedDispatcher(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
                (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (
                    UpdateWeightsFromDistributedReqInput,
                    self.update_weights_from_distributed,
                ),
                (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
                (GetWeightsByNameReqInput, self.get_weights_by_name),
                (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
                (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                (SlowDownReqInput, self.slow_down),
                (ProfileReq, self.profile),
                (GetInternalStateReq, self.get_internal_state),
                (SetInternalStateReq, self.set_internal_state),
                (RpcReqInput, self.handle_rpc_request),
                (ExpertDistributionReq, self.expert_distribution_handle),
            ]
        )

        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.init_disaggregation()

    def maybe_sleep_on_idle(self):
        """Block on the IPC sockets if idle sleeping is configured, else return."""
        if self.idle_sleeper is not None:
            self.idle_sleeper.maybe_sleep()

    def init_tokenizer(self):
        """Load the model config and, unless skipped, the tokenizer/processor.

        Sets ``self.model_config``, ``self.is_generation``, ``self.tokenizer``
        and ``self.processor``. For multimodal models the tokenizer is taken
        from the loaded processor. Text-only models get ``processor = None``,
        so the attribute is defined on every path.
        """
        server_args = self.server_args

        self.model_config = ModelConfig.from_server_args(server_args)
        self.is_generation = self.model_config.is_generation

        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
        elif self.model_config.is_multimodal:
            self.processor = get_processor(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
                use_fast=not server_args.disable_fast_image_processor,
            )
            self.tokenizer = get_tokenizer_from_processor(self.processor)
        else:
            self.processor = None
            self.tokenizer = get_tokenizer(
                server_args.tokenizer_path,
                tokenizer_mode=server_args.tokenizer_mode,
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )

    def init_memory_pool_and_cache(self):
        """Fetch the memory pools from the TP worker and build the prefix cache.

        Cache selection:
          * ``ChunkCache`` when ``chunked_prefill_size`` is not None and the
            radix cache is disabled;
          * otherwise ``HiRadixCache`` if hierarchical caching is enabled;
          * otherwise ``RadixCache`` (forwarding ``disable_radix_cache``).

        Also sets ``decode_mem_cache_buf_multiplier``: 1 without speculative
        decoding, else ``num_draft_tokens + eagle_topk * num_steps``.
        """
        server_args = self.server_args

        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            self.tp_worker.get_memory_pool()
        )

        if (
            server_args.chunked_prefill_size is not None
            and server_args.disable_radix_cache
        ):
            self.tree_cache = ChunkCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
            )
        elif self.enable_hierarchical_cache:
            self.tree_cache = HiRadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                tp_cache_group=self.tp_cpu_group,
                page_size=self.page_size,
                hicache_ratio=server_args.hicache_ratio,
                hicache_size=server_args.hicache_size,
                hicache_write_policy=server_args.hicache_write_policy,
            )
        else:
            self.tree_cache = RadixCache(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                page_size=self.page_size,
                disable=server_args.disable_radix_cache,
                enable_kv_cache_events=self.enable_kv_cache_events,
            )

        self.decode_mem_cache_buf_multiplier = (
            1
            if self.spec_algorithm.is_none()
            else (
                server_args.speculative_num_draft_tokens
                + (
                    server_args.speculative_eagle_topk
                    * server_args.speculative_num_steps
                )
            )
        )

    def init_metrics(self):
        """Reset throughput / speculative-decoding counters and stats.

        When metrics are enabled, also create a ``SchedulerMetricsCollector``
        labelled with the served model name and engine type "unified".
        """
        self.last_gen_throughput: float = 0.0
        self.last_input_throughput: float = 0.0
        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        self.stats = SchedulerStats()
        if self.enable_metrics:
            engine_type = "unified"
            self.metrics_collector = SchedulerMetricsCollector(
                labels={
                    "model_name": self.server_args.served_model_name,
                    "engine_type": engine_type,
                },
            )

    def init_kv_events(self, kv_events_config: Optional[str]):
        """Create the KV-cache event publisher when KV-cache events are enabled.

        Args:
            kv_events_config: Publisher configuration forwarded to
                ``EventPublisherFactory.create`` together with this
                scheduler's attention-DP rank.
        """
        if self.enable_kv_cache_events:
            self.kv_event_publisher = EventPublisherFactory.create(
                kv_events_config, self.attn_dp_rank
            )

    def init_disaggregation(self):
        """Create the transfer backend and the queues used by PD disaggregation.

        * DECODE mode: metadata buffers sized 2x ``req_to_token_pool.size``,
          a transfer queue (requests polling KV cache) and a pre-allocation
          queue feeding it.
        * PREFILL mode: metadata buffers sized 2x ``max_running_requests``,
          a bootstrap queue and an (initially empty) inflight list.
        * Any other mode: only ``self.transfer_backend`` is set.
        """
        self.transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # *2 for the headroom.
            buffer_size = self.req_to_token_pool.size * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
                gloo_group=self.attn_tp_cpu_group,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                tree_cache=self.tree_cache,
            )

            # The decode requests pending for pre-allocation
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,
                transfer_queue=self.disagg_decode_transfer_queue,
                tree_cache=self.tree_cache,
                gloo_group=self.attn_tp_cpu_group,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )

            # Metric for pre-allocation
            self.num_tokens_pre_allocated = 0

        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                draft_token_to_kv_pool=(
                    None
                    if self.draft_worker is None
                    else self.draft_worker.model_runner.token_to_kv_pool
                ),
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
                tp_size=self.tp_size,
                gpu_id=self.gpu_id,
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                gloo_group=self.attn_tp_cpu_group,
                max_total_num_tokens=self.max_total_num_tokens,
                decode_tp_size=self.server_args.disaggregation_decode_tp,
                decode_dp_size=self.server_args.disaggregation_decode_dp,
                scheduler=self,
                pp_rank=self.pp_rank,
                pp_size=self.pp_size,
                transfer_backend=self.transfer_backend,
            )
            # The prefill requests that are in the middle of kv sending
            self.disagg_prefill_inflight_queue: List[Req] = []

    @DynamicGradMode()
    def event_loop_normal(self):
        """Run the default (non-overlapped) scheduler loop forever.

        Every iteration ingests newly received requests, picks the next batch,
        and, if there is one, runs it and processes its result before the next
        iteration starts. With nothing to run, the scheduler performs its idle
        self-check and resets ``new_token_ratio``.
        """
        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if not batch:
                # Idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()
            else:
                self.process_batch_result(batch, self.run_batch(batch))

            self.last_batch = batch
713

714
    @DynamicGradMode()
    def event_loop_overlap(self):
        """Scheduler loop that overlaps CPU-side processing with GPU computation.

        Each iteration launches the current batch and then post-processes the
        batch launched in the *previous* iteration. Launched batches wait in
        ``self.result_queue`` (a deque of ``(batch_copy, result)`` pairs) until
        they are processed one iteration later.
        """
        self.result_queue = deque()

        while True:
            self.process_input_requests(self.recv_requests())

            batch = self.get_next_batch_to_run()
            self.cur_batch = batch

            if batch:
                batch.launch_done = threading.Event()
                launched_result = self.run_batch(batch)
                self.result_queue.append((batch.copy(), launched_result))

                if self.last_batch is None:
                    # Nothing to post-process yet: push a dummy first batch through
                    # process_batch_result to start the overlap pipeline. It is now
                    # used for triggering the sampling_info_done event.
                    dummy_batch = ScheduleBatch(
                        reqs=None,
                        forward_mode=ForwardMode.DUMMY_FIRST,
                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                    )
                    self.process_batch_result(dummy_batch, None, batch.launch_done)

            if self.last_batch:
                # Post-process the batch that was launched in the previous iteration.
                prev_batch, prev_result = self.result_queue.popleft()
                if batch:
                    prev_batch.next_batch_sampling_info = (
                        self.tp_worker.cur_sampling_info
                    )
                    # NOTE: hand over the launch_done event of the batch launched
                    # in *this* iteration, not the previous batch's.
                    current_launch_done = batch.launch_done
                else:
                    prev_batch.next_batch_sampling_info = None
                    current_launch_done = None
                self.process_batch_result(prev_batch, prev_result, current_launch_done)
            elif batch is None:
                # Idle: self-check and re-init some states.
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()

            self.last_batch = batch

759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Requests are scheduled as ``pp_size`` micro-batches in round-robin order,
        each with its own running batch in ``self.running_mbs``. At step
        ``mb_id`` this rank runs micro-batch ``mb_id``. If micro-batch
        ``mb_id + 1`` (launched one round earlier) exists, it then receives that
        micro-batch's ``next_token_ids`` over ``pp_group`` and post-processes it.
        The last rank sends its sampled token ids. Every other rank forwards the
        previously received outputs, the requests it received (attention-TP rank
        0 only) and its hidden-state proxy tensors to the next stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        bids = [None] * self.pp_size
        # CUDA-graph flag of each micro-batch, recorded when that micro-batch is
        # run. Post-processing of micro-batch ``next_mb_id`` must report its own
        # flag. ``result`` is overwritten every step and belongs to whichever
        # micro-batch ran most recently.
        mb_can_run_cuda_graph = [False] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)
                    mb_can_run_cuda_graph[mb_id] = result.can_run_cuda_graph

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids, bids[mb_id] = (
                            result.next_token_ids,
                            result.bid,
                        )
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        bid=bids[next_mb_id],
                        can_run_cuda_graph=mb_can_run_cuda_graph[next_mb_id],
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    if self.cur_batch:
                        bids[mb_id] = result.bid
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.cpu_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                self.check_memory()
                self.new_token_ratio = self.init_new_token_ratio
                self.maybe_sleep_on_idle()
862

863
864
    def recv_requests(self) -> List[Req]:
        """Collect newly arrived requests and share them with every TP rank.

        Only attention-TP rank 0 receives anything itself. On the first pipeline
        stage it drains, without blocking, first the tokenizer socket and then
        the RPC socket. On later stages it takes the request list sent point to
        point by the previous stage. The list is then broadcast so that all TP
        ranks return the same requests. With DP attention, generate/embedding
        requests are broadcast only inside the attention-TP group, while control
        requests are broadcast to the whole TP group.
        """
        recv_reqs = None
        if self.attn_tp_rank == 0:
            if self.pp_rank == 0:

                def _drain(socket):
                    # Yield every message already queued on ``socket``; stop as
                    # soon as a non-blocking receive reports nothing is pending.
                    while True:
                        try:
                            yield socket.recv_pyobj(zmq.NOBLOCK)
                        except zmq.ZMQError:
                            return

                recv_reqs = list(_drain(self.recv_from_tokenizer))
                recv_reqs.extend(_drain(self.recv_from_rpc))
            else:
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
                    self.world_group.cpu_group,
                    (self.pp_rank - 1) * self.tp_size + dp_offset,
                    self.pp_rank * self.tp_size + dp_offset,
                )

        if self.server_args.enable_dp_attention:
            work_reqs = control_reqs = None
            if self.attn_tp_rank == 0:
                work_reqs, control_reqs = [], []
                for req in recv_reqs:
                    is_work_req = isinstance(
                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                    )
                    (work_reqs if is_work_req else control_reqs).append(req)

            if self.attn_tp_size != 1:
                work_reqs = broadcast_pyobj(
                    work_reqs,
                    self.attn_tp_group.rank,
                    self.attn_tp_cpu_group,
                    src=self.attn_tp_group.ranks[0],
                )
            if self.tp_size != 1:
                control_reqs = broadcast_pyobj(
                    control_reqs,
                    self.tp_group.rank,
                    self.tp_cpu_group,
                    src=self.tp_group.ranks[0],
                )
            return work_reqs + control_reqs

        if self.tp_size != 1:
            recv_reqs = broadcast_pyobj(
                recv_reqs,
                self.tp_group.rank,
                self.tp_cpu_group,
                src=self.tp_group.ranks[0],
            )
        return recv_reqs

Lianmin Zheng's avatar
Lianmin Zheng committed
941
    def process_input_requests(self, recv_reqs: List):
        """Dispatch each received request and route any reply back to its sender.

        A health-check generate request is not dispatched while a chunked
        request or a non-empty running batch exists; it is only counted in
        ``return_health_check_ct``. Replies of type ``RpcReqOutput`` go back over
        the RPC socket (when this rank has one); every other reply goes to the
        tokenizer.
        """
        for recv_req in recv_reqs:
            # If it is a health check generation request and there are running requests, ignore it.
            skip_health_check = is_health_check_generate_req(recv_req) and (
                self.chunked_req is not None or not self.running_batch.is_empty()
            )
            if skip_health_check:
                self.return_health_check_ct += 1
                continue

            reply = self._request_dispatcher(recv_req)
            if reply is None:
                continue

            if not isinstance(reply, RpcReqOutput):
                self.send_to_tokenizer.send_pyobj(reply)
            elif self.recv_from_rpc is not None:
                self.recv_from_rpc.send_pyobj(reply)
957
958
959
960
961

    def handle_generate_request(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
        """Build a ``Req`` from a tokenized generate request and enqueue it.

        A request without a session, or with an unknown session id, gets a fresh
        ``Req``. A request for an existing session is created by that session.
        Invalid requests are marked as aborted and still enqueued. The exception
        is a disaggregated request without a bootstrap room, which is streamed
        out immediately. A request with a structured-output constraint
        (json_schema / regex / ebnf / structural_tag) whose grammar is not cached
        yet is parked in ``self.grammar_queue`` instead of the normal queue.
        """
        # Create a new request
        if (
            recv_req.session_params is None
            or recv_req.session_params.id is None
            or recv_req.session_params.id not in self.sessions
        ):
            if recv_req.input_embeds is not None:
                # Generate fake input_ids based on the length of input_embeds
                seq_length = len(recv_req.input_embeds)
                fake_input_ids = [1] * seq_length
                recv_req.input_ids = fake_input_ids

            if recv_req.bootstrap_port is None:
                # Use default bootstrap port
                recv_req.bootstrap_port = self.server_args.disaggregation_bootstrap_port

            req = Req(
                recv_req.rid,
                recv_req.input_text,
                recv_req.input_ids,
                recv_req.sampling_params,
                return_logprob=recv_req.return_logprob,
                top_logprobs_num=recv_req.top_logprobs_num,
                token_ids_logprob=recv_req.token_ids_logprob,
                stream=recv_req.stream,
                lora_path=recv_req.lora_path,
                input_embeds=recv_req.input_embeds,
                custom_logit_processor=recv_req.custom_logit_processor,
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
                bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
                data_parallel_rank=recv_req.data_parallel_rank,
            )
            req.tokenizer = self.tokenizer

            if self.disaggregation_mode != DisaggregationMode.NULL:
                # Invalid request for disaggregated mode
                if recv_req.bootstrap_room is None:
                    error_msg = (
                        f"Invalid request: Disaggregated request received without "
                        f"boostrap room id. {req.rid=}"
                    )
                    logger.error(error_msg)
                    prepare_abort(req, error_msg)
                    self.stream_output([req], req.return_logprob)
                    return

            if (
                recv_req.session_params is not None
                and recv_req.session_params.id is not None
            ):
                # A session id was given but it is unknown.
                req.finished_reason = FINISH_ABORT(
                    f"Invalid request: session id {recv_req.session_params.id} does not exist"
                )
                self._add_request_to_queue(req)
                return
        else:
            # Create a new request from a previous session
            session = self.sessions[recv_req.session_params.id]
            req = session.create_req(recv_req, self.tokenizer)
            if isinstance(req.finished_reason, FINISH_ABORT):
                self._add_request_to_queue(req)
                return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompt length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
            # By default, only return the logprobs for output tokens
            req.logprob_start_len = len(req.origin_input_ids) - 1
        else:
            req.logprob_start_len = recv_req.logprob_start_len

        if req.logprob_start_len >= len(req.origin_input_ids):
            error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
            req.logprob_start_len = len(req.origin_input_ids) - 1
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Cap max_new_tokens so that prompt + output fits in max_req_len.
        req.sampling_params.max_new_tokens = min(
            (
                req.sampling_params.max_new_tokens
                if req.sampling_params.max_new_tokens is not None
                else 1 << 30
            ),
            self.max_req_len - len(req.origin_input_ids) - 1,
        )

        # Init grammar cache for this request
        add_to_grammar_queue = False
        sampling_params = req.sampling_params
        if (
            sampling_params.json_schema is not None
            or sampling_params.regex is not None
            or sampling_params.ebnf is not None
            or sampling_params.structural_tag is not None
        ):
            assert self.grammar_backend is not None
            if sampling_params.json_schema is not None:
                key = ("json", sampling_params.json_schema)
            elif sampling_params.regex is not None:
                key = ("regex", sampling_params.regex)
            elif sampling_params.ebnf is not None:
                key = ("ebnf", sampling_params.ebnf)
            else:
                # The enclosing condition guarantees structural_tag is not None.
                # A truthiness test here would leave ``key`` unbound for a falsy
                # (but non-None) tag and crash with UnboundLocalError.
                key = ("structural_tag", sampling_params.structural_tag)

            value, cache_hit = self.grammar_backend.get_cached_or_future_value(key)
            req.grammar = value

            if not cache_hit:
                # Grammar is still being compiled; wait in the grammar queue.
                req.grammar_key = key
                add_to_grammar_queue = True
            elif value is INVALID_GRAMMAR_OBJ:  # We hit a cached invalid grammar.
                error_msg = f"Invalid grammar request with cache hit: {key=}"
                req.set_finish_with_abort(error_msg)

        if add_to_grammar_queue:
            req.queue_time_start = time.perf_counter()
            self.grammar_queue.append(req)
        else:
            self._add_request_to_queue(req)

    def _add_request_to_queue(self, req: Req):
        """Timestamp ``req`` and put it into the queue that matches the serving mode.

        In PREFILL disaggregation mode it goes to the bootstrap queue. In DECODE
        mode it goes to the pre-allocation queue. Otherwise it goes to the local
        ``waiting_queue``.
        """
        req.queue_time_start = time.perf_counter()

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.add(
                req, self.model_config.num_key_value_heads
            )
            return

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
            return

        self.waiting_queue.append(req)

1127
1128
    def _extend_requests_to_queue(self, reqs: List[Req]):
        """Bulk variant of ``_add_request_to_queue`` that does not touch timestamps.

        ``reqs`` go to the PREFILL bootstrap queue, the DECODE pre-allocation
        queue, or the local ``waiting_queue``, depending on the disaggregation
        mode.
        """
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.disagg_prefill_bootstrap_queue.extend(
                reqs, self.model_config.num_key_value_heads
            )
            return

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            # A decode server parks requests in the decode pending prealloc queue.
            self.disagg_decode_prealloc_queue.extend(reqs)
            return

        self.waiting_queue.extend(reqs)
1137
1138
1139

    def handle_embedding_request(
        self,
        recv_req: TokenizedEmbeddingReqInput,
    ):
        """Build a ``Req`` from a tokenized embedding request and enqueue it.

        Multimodal inputs are expanded into placeholder tokens first. A request
        that is too long, after expansion or by ``validate_input_length``, is
        aborted through ``set_finish_with_abort`` and still enqueued, exactly as
        in ``handle_generate_request``.
        """
        req = Req(
            recv_req.rid,
            recv_req.input_text,
            recv_req.input_ids,
            recv_req.sampling_params,
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                req.set_finish_with_abort(
                    error_msg=(
                        "Multimodal prompt is too long after expanding multimodal tokens. "
                        f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                    )
                )
                self._add_request_to_queue(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
            self.max_req_input_len,
            self.server_args.allow_auto_truncate,
        )
        if error_msg:
            # Abort exactly like the generate path does. Otherwise the over-long
            # input stays on the request and would still be prefilled.
            req.set_finish_with_abort(error_msg)
            self._add_request_to_queue(req)
            return

        # Copy more attributes
        req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)
1182

1183
1184
1185
1186
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
        can_run_list: List[Req],
        running_bs: int,
    ):
        """Log a one-line summary of a new prefill batch and update metrics.

        Args:
            adder: the ``PrefillAdder`` that built the batch; its
                ``log_input_tokens`` / ``log_hit_tokens`` counters are reported.
            can_run_list: requests admitted into this prefill batch.
            running_bs: reported as ``#running-req`` and ``stats.num_running_reqs``.

        Also resets ``num_prefill_tokens`` and publishes pending KV events.
        """
        now = time.perf_counter()
        gap_latency = now - self.last_prefill_stats_tic
        self.last_prefill_stats_tic = now

        self.last_input_throughput = self.num_prefill_tokens / gap_latency
        self.num_prefill_tokens = 0

        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        num_new_seq = len(can_run_list)
        f = (
            f"Prefill batch. "
            f"#new-seq: {num_new_seq}, "
            f"#new-token: {adder.log_input_tokens}, "
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
            f += f"#queue-req: {len(self.waiting_queue)}, "
            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
        else:
            f += f"#running-req: {running_bs}, "
            f += f"#queue-req: {len(self.waiting_queue)}"

        logger.info(f)

        if self.enable_metrics:
            # Both counters can be zero; avoid ZeroDivisionError.
            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
            cache_hit_rate = (
                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
            )
            self.stats.num_running_reqs = running_bs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.cache_hit_rate = cache_hit_rate

            # An empty can_run_list would otherwise divide by zero.
            if num_new_seq > 0:
                total_queue_latency = sum(
                    req.queue_time_end - req.queue_time_start for req in can_run_list
                )
                self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
            else:
                self.stats.avg_request_queue_latency = 0.0

            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1236

1237
1238
1239
    def log_decode_stats(
        self, can_run_cuda_graph: bool, running_batch: Optional[ScheduleBatch] = None
    ):
        """Log a one-line summary of recent decode progress and update metrics.

        Args:
            can_run_cuda_graph: reported as ``cuda graph`` in the log line.
            running_batch: batch to report on. Falls back to
                ``self.running_batch`` when omitted or falsy.

        Resets ``num_generated_tokens``. For speculative decoding it also resets
        the accepted-token / forward counters after folding them into the
        cumulative totals. Publishes pending KV events.
        """
        batch = running_batch or self.running_batch

        now = time.perf_counter()
        gap_latency = now - self.last_decode_stats_tic
        self.last_decode_stats_tic = now
        self.last_gen_throughput = self.num_generated_tokens / gap_latency
        self.num_generated_tokens = 0
        num_running_reqs = len(batch.reqs)
        num_used = self.max_total_num_tokens - (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )

        if RECORD_STEP_TIME:
            self.step_time_dict[num_running_reqs].append(
                gap_latency / self.server_args.decode_log_interval
            )

        msg = (
            f"Decode batch. "
            f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
        )

        if self.spec_algorithm.is_none():
            spec_accept_length = 0
        else:
            # Guard against no speculative forward since the last log.
            spec_accept_length = (
                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
                if self.spec_num_total_forward_ct > 0
                else 0
            )
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, "

        if self.disaggregation_mode == DisaggregationMode.DECODE:
            msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "

        msg += (
            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )

        logger.info(msg)
        if self.enable_metrics:
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.cache_hit_rate = 0.0
            self.stats.gen_throughput = self.last_gen_throughput
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1296

Lianmin Zheng's avatar
Lianmin Zheng committed
1297
1298
    def check_memory(self):
        """Verify that no KV-cache tokens or request slots have leaked.

        The event loops call this when there is no batch to run. At that point
        the free plus evictable KV tokens must equal ``max_total_num_tokens``
        (minus the tree cache's protected size when hierarchical caching is
        enabled), and every ``req_to_token_pool`` slot must be free.

        Also logs idle-time metrics (attention-TP rank 0, at most every 30 s)
        and publishes pending KV events.

        Raises:
            ValueError: if either pool reports a leak.
        """
        available_size = (
            self.token_to_kv_pool_allocator.available_size()
            + self.tree_cache.evictable_size()
        )
        protected_size = self.tree_cache.protected_size()
        expected_size = (
            self.max_total_num_tokens - protected_size
            if self.enable_hierarchical_cache
            else self.max_total_num_tokens
        )
        if available_size != expected_size:
            msg = (
                "token_to_kv_pool_allocator memory leak detected! "
                f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                f"{self.tree_cache.evictable_size()=}\n"
            )
            raise ValueError(msg)

        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
            msg = (
                # Trailing space separates the headline from the details
                # (previously rendered as "detected!available_size=...").
                "req_to_token_pool memory leak detected! "
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
                f"total_size={self.req_to_token_pool.size}\n"
            )
            raise ValueError(msg)

        if (
            self.enable_metrics
            and self.attn_tp_rank == 0
            and time.perf_counter() > self.metrics_collector.last_log_time + 30
        ):
            # During idle time, also collect metrics every 30 seconds.
            # Nothing above mutates the pools, so available_size is still current.
            num_used = self.max_total_num_tokens - available_size
            num_running_reqs = len(self.running_batch.reqs)
            self.stats.num_running_reqs = num_running_reqs
            self.stats.num_used_tokens = num_used
            self.stats.token_usage = num_used / self.max_total_num_tokens
            self.stats.gen_throughput = 0
            self.stats.num_queue_reqs = len(self.waiting_queue)
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.metrics_collector.log_stats(self.stats)
        self._publish_kv_events()
1344

1345
    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
        """Choose the batch to execute in this scheduler iteration.

        The previous extend batch is first folded into ``running_batch``, minus
        any request that is still being chunked. Then a fresh prefill batch is
        preferred; if none is available, the running batch is updated for
        decode. When DP attention or SP layernorm is enabled, the chosen batch
        (possibly None) is passed through ``prepare_dp_attn_batch``.
        """
        excluded_reqs = set()
        current_chunked = self.chunked_req
        if current_chunked:
            # Pull the chunked request out so only finished requests get merged
            # into running_batch. It keeps its rid but gets a new req_pool_idx.
            excluded_reqs.add(current_chunked)
            self.tree_cache.cache_unfinished_req(current_chunked)
            self.req_to_token_pool.free(current_chunked.req_pool_idx)

        prev_batch = self.last_batch
        if prev_batch and prev_batch.forward_mode.is_extend():
            stale_chunked = prev_batch.chunked_req
            if stale_chunked is not None:
                # Under pipeline parallelism a microbatch can still reference a
                # chunked_req after its last chunk; drop that outdated entry too.
                excluded_reqs.add(stale_chunked)

            size_before = prev_batch.batch_size()
            prev_batch.filter_batch(chunked_req_to_exclude=list(excluded_reqs))
            if prev_batch.batch_size() < size_before:
                self.running_batch.batch_is_full = False

            # Fold whatever is left of the extend batch into the running batch.
            if not prev_batch.is_empty():
                if self.running_batch.is_empty():
                    self.running_batch = prev_batch
                else:
                    self.running_batch.merge_batch(prev_batch)

        next_batch = self.get_new_batch_prefill()
        if next_batch is None and not self.running_batch.is_empty():
            # No prefill possible: fall back to decoding the running batch.
            self.running_batch = self.update_running_batch(self.running_batch)
            if not self.running_batch.is_empty():
                next_batch = self.running_batch

        server_args = self.server_args
        if server_args.enable_dp_attention or server_args.enable_sp_layernorm:
            next_batch, _ = self.prepare_dp_attn_batch(next_batch)

        return next_batch
1394

1395
1396
1397
1398
1399
1400
    def get_num_allocatable_reqs(self, running_bs):
        """Return how many more requests may be admitted next to ``running_bs``.

        The budget is ``max_micro_batch_size`` minus the running batch size.
        With pipeline parallelism (``pp_size > 1``) it is additionally capped by
        the number of free ``req_to_token_pool`` slots.
        """
        budget = global_server_args_dict["max_micro_batch_size"] - running_bs
        if self.pp_size <= 1:
            return budget
        return min(budget, self.req_to_token_pool.available_size())

Lianmin Zheng's avatar
Lianmin Zheng committed
1401
    def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
        """Build a new prefill (extend) batch from the waiting queue.

        Requests are admitted in policy order through a ``PrefillAdder`` until a
        limit is hit: the LoRA adapter budget, the allocatable-request budget,
        the req_to_token_pool capacity (prefill disaggregation mode), or the
        adder refusing a request. Admitted requests are removed from
        ``waiting_queue``. With mixed chunked prefill, the running decode batch
        may be merged into the new batch.

        Returns:
            The prepared extend batch, or None if no request can be scheduled.
        """
        # Check if the grammar is ready in the grammar queue
        if self.grammar_queue:
            self.move_ready_grammar_requests()

        # Handle the cases where prefill is not allowed
        if (
            self.running_batch.batch_is_full or len(self.waiting_queue) == 0
        ) and self.chunked_req is None:
            return None

        running_bs = len(self.running_batch.reqs)
        # Ignore the check if self.chunked_req is not None.
        # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
        # as the space for the chunked request has just been released.
        # In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
        # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
        if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
            self.running_batch.batch_is_full = True
            return None

        if self.enable_hierarchical_cache:
            # check for completion of hierarchical cache activities to release memory
            self.tree_cache.writing_check()
            self.tree_cache.loading_check()

        # Get priority queue
        prefix_computed = self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
            self.new_token_ratio,
            self.max_prefill_tokens,
            self.chunked_prefill_size,
            running_bs if self.is_mixed_chunk else 0,
        )

        if self.chunked_req is not None:
            self.chunked_req.init_next_round_input()
            self.chunked_req = adder.add_chunked_req(self.chunked_req)

        # LoRA adapters already in use by the running batch (only consulted when
        # LoRA is enabled; always bound so the loop below never sees a NameError).
        lora_set = (
            {r.lora_path for r in self.running_batch.reqs} if self.lora_paths else set()
        )

        # Get requests from the waiting queue to a new prefill batch
        for req in self.waiting_queue:
            if (
                self.lora_paths
                and len(
                    lora_set
                    | {r.lora_path for r in adder.can_run_list}
                    | {req.lora_path}
                )
                > self.max_loras_per_batch
            ):
                self.running_batch.batch_is_full = True
                break

            if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
                self.running_batch.batch_is_full = True
                break

            if self.disaggregation_mode == DisaggregationMode.PREFILL:
                # In prefill mode, prealloc queue and transfer queue can also take memory,
                # so we need to check against the actual available size.
                if len(adder.can_run_list) >= self.req_to_token_pool.available_size():
                    self.running_batch.batch_is_full = True
                    break

            req.init_next_round_input(
                None if prefix_computed else self.tree_cache,
                self.enable_hierarchical_cache,
            )

            res = adder.add_one_req(
                req, self.chunked_req, self.enable_hierarchical_cache
            )

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
                    if self.enable_hierarchical_cache:
                        # Set batch_is_full after making sure there are requests that can be served
                        self.running_batch.batch_is_full = len(
                            adder.can_run_list
                        ) > 0 or (not self.running_batch.is_empty())
                    else:
                        self.running_batch.batch_is_full = True
                break

        # Update waiting queue
        can_run_list: List[Req] = adder.can_run_list
        if len(can_run_list) == 0:
            return None

        if self.enable_metrics:
            # only record queue time when enable_metrics is True to avoid overhead
            for req in can_run_list:
                req.queue_time_end = time.perf_counter()

        # Build the membership set once; rebuilding it per element was O(n*m).
        scheduled = set(can_run_list)
        self.waiting_queue = [x for x in self.waiting_queue if x not in scheduled]

        if self.enable_hierarchical_cache:
            self.tree_cache.ready_to_load_cache()

        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req

        if self.chunked_req:
            self.chunked_req.is_chunked += 1

        # Print stats
        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs)

        # Create a new batch
        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
        if (
            self.is_mixed_chunk
            and not self.running_batch.is_empty()
            and not (new_batch.return_logprob or self.running_batch.return_logprob)
        ):
            # TODO (lianmin): support return_logprob + mixed chunked prefill
            self.running_batch.filter_batch()
            if not self.running_batch.is_empty():
                self.running_batch.prepare_for_decode()
                new_batch.mix_with_running(self.running_batch)
                new_batch.decoding_reqs = self.running_batch.reqs
            self.running_batch = ScheduleBatch(
                reqs=[], batch_is_full=self.running_batch.batch_is_full
            )
        else:
            new_batch.decoding_reqs = None

        return new_batch

Lianmin Zheng's avatar
Lianmin Zheng committed
1555
    def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
        """Update the current running decoding batch.

        Filters the batch via ``filter_batch()``. An empty result is returned
        immediately with ``batch_is_full`` cleared.

        Otherwise, if ``check_decode_mem`` fails (or ``TEST_RETRACT`` forces it
        for batches larger than 10), requests are retracted with
        ``retract_decode``. Those requests are re-queued via
        ``_extend_requests_to_queue``, and ``new_token_ratio`` is replaced by
        the value ``retract_decode`` returns. If memory suffices,
        ``new_token_ratio`` decays towards ``min_new_token_ratio``.

        The batch is then prepared for decode and returned.
        """
        size_before = batch.batch_size()

        batch.filter_batch()
        if batch.is_empty():
            batch.batch_is_full = False
            return batch

        fits_in_memory = batch.check_decode_mem(self.decode_mem_cache_buf_multiplier)
        if fits_in_memory and not (TEST_RETRACT and batch.batch_size() > 10):
            # Enough room for the decode step: relax the ratio towards its floor.
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
                self.min_new_token_ratio,
            )
        else:
            ratio_before = self.new_token_ratio
            retracted_reqs, self.new_token_ratio = batch.retract_decode(
                self.server_args
            )
            logger.info(
                "KV cache pool is full. Retract requests. "
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {ratio_before:.4f} -> {self.new_token_ratio:.4f}"
            )
            self._extend_requests_to_queue(retracted_reqs)

        # Shrinking the batch frees capacity for new requests.
        if batch.batch_size() < size_before:
            batch.batch_is_full = False

        # Update batch tensors
        batch.prepare_for_decode()
        return batch
Lianmin Zheng's avatar
Lianmin Zheng committed
1591

1592
1593
1594
    def run_batch(
        self, batch: ScheduleBatch
    ) -> Union[GenerationBatchResult, EmbeddingBatchResult]:
        """Run a batch.

        Increments ``forward_ct``, lets the profiler predicate inspect the batch,
        optionally sleeps for ``forward_sleep_time`` seconds, then runs the
        forward pass:

        * Generation models use ``tp_worker.forward_batch_generation``, or
          ``draft_worker.forward_batch_speculative_generation`` when a
          speculative algorithm is configured. They return a
          ``GenerationBatchResult``.
        * Other models (embedding / reward) use
          ``tp_worker.forward_batch_embedding`` and return an
          ``EmbeddingBatchResult``.

        With pipeline parallelism, only the last PP rank keeps logits and
        next-token ids. Non-last ranks return ``pp_hidden_states_proxy_tensors``
        instead.
        """
        self.forward_ct += 1

        # Whether to run the profiler
        self._profile_batch_predicate(batch)
        # Optional artificial delay per forward (only when forward_sleep_time is set).
        if self.forward_sleep_time is not None:
            logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
            time.sleep(self.forward_sleep_time)

        # Run forward
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                else:
                    # Non-last PP rank: the first returned value is the proxy
                    # tensors to forward to the next stage; there are no next-token ids.
                    pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
                    )
                bid = model_worker_batch.bid
            else:
                # NOTE(review): this branch never binds pp_hidden_states_proxy_tensors,
                # so on a non-last PP rank the GenerationBatchResult below would
                # raise UnboundLocalError. Presumably speculative decoding is not
                # combined with PP; confirm.
                (
                    logits_output,
                    next_token_ids,
                    bid,
                    num_accepted_tokens,
                    can_run_cuda_graph,
                ) = self.draft_worker.forward_batch_speculative_generation(batch)
                bs = batch.batch_size()
                # Accepted-token accounting adds one extra token per request per
                # forward on top of the accepted draft tokens.
                self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
                self.spec_num_total_forward_ct += bs
                self.num_generated_tokens += num_accepted_tokens

            if self.pp_group.is_last_rank:
                batch.output_ids = next_token_ids

            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
            if batch.return_logprob:
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            # Exactly one of logits_output/next_token_ids (last PP rank) or
            # pp_hidden_states_proxy_tensors (other ranks) is populated.
            ret = GenerationBatchResult(
                logits_output=logits_output if self.pp_group.is_last_rank else None,
                pp_hidden_states_proxy_tensors=(
                    pp_hidden_states_proxy_tensors
                    if not self.pp_group.is_last_rank
                    else None
                ),
                next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
                extend_input_len_per_req=extend_input_len_per_req,
                extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                bid=bid,
                can_run_cuda_graph=can_run_cuda_graph,
            )
        else:  # embedding or reward model
            model_worker_batch = batch.get_model_worker_batch()
            embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
            ret = EmbeddingBatchResult(
                embeddings=embeddings, bid=model_worker_batch.bid
            )
        return ret
Chayenne's avatar
Chayenne committed
1665

1666
1667
1668
1669
    def process_batch_result(
        self,
        batch: ScheduleBatch,
        result: Union[GenerationBatchResult, EmbeddingBatchResult],
        launch_done: Optional[threading.Event] = None,
    ):
        """Route a forward result to the handler for the batch's forward mode.

        Decode and extend batches go to their dedicated handlers. Idle batches
        (only when overlap scheduling is enabled) resolve the last batch result
        first. Idle and dummy-first batches then mark the next batch's sampling
        info as done. Finally, if health-check signals are outstanding, one
        ``HealthCheckOutput`` is sent to the tokenizer.
        """
        mode = batch.forward_mode
        if mode.is_decode():
            self.process_batch_result_decode(batch, result, launch_done)
        elif mode.is_extend():
            self.process_batch_result_prefill(batch, result, launch_done)
        elif mode.is_idle():
            if self.enable_overlap:
                self.tp_worker.resolve_last_batch_result(launch_done)
                self.set_next_batch_sampling_info_done(batch)
        elif mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

        if not self.return_health_check_ct:
            return
        # Answer health checks from here so a long-context prefill cannot block
        # them. Caveat: this path does not check the detokenizer manager.
        self.return_health_check_ct -= 1
        self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

1690
    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
        """Bind this scheduler's configuration to ``prepare_dp_attn_batch_raw``.

        Returns whatever the raw helper returns:
        ``(local_batch, is_extend_in_batch)``.
        """
        args = self.server_args
        return self.prepare_dp_attn_batch_raw(
            local_batch,
            dp_size=args.dp_size,
            attn_tp_size=self.attn_tp_size,
            moe_dense_tp_size=args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=args.speculative_num_draft_tokens,
            enable_two_batch_overlap=args.enable_two_batch_overlap,
            enable_deepep_moe=args.enable_deepep_moe,
            deepep_mode=DeepEPMode[args.deepep_mode],
        )

    @staticmethod
    def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        """All-gather per-rank batch metadata across DP attention ranks.

        Each rank contributes one int64 row containing:

        * its token count;
        * whether it can use CUDA graphs;
        * its logprob token count;
        * whether it is an extend batch;
        * the values from ``TboDPAttentionPreparer.prepare_all_gather``
          (presumably two, since the gathered row width is 6 — confirm).

        The rows are gathered into a ``(dp_size, attn_tp_size, 6)`` tensor over
        ``tp_cpu_group``. Only index ``[:, 0, ...]`` (the first attention-TP
        rank of each DP group) is read back.

        If this rank has no batch but some DP rank has tokens, an idle batch is
        created via ``get_idle_batch`` so this rank also has a batch to run.

        Returns:
            ``(local_batch, any_extend)``. ``local_batch`` stays None when no DP
            rank has tokens. ``any_extend`` is True if any gathered rank
            reported an extend batch.
        """
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            # EAGLE verifies speculative_num_draft_tokens tokens per request.
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
                    for logprob_start_len, extend_len in zip(
                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
                    )
                ]
            )

        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
            can_cuda_graph = 1
        else:
            can_cuda_graph = 0

        if not spec_algorithm.is_none():
            # TODO(sang): Support cuda graph when idle batch is there.
            if local_batch is None or local_batch.forward_mode.is_idle():
                can_cuda_graph = 0

        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        # Layout of one row: [num_tokens, can_cuda_graph, num_tokens_for_logprob,
        # is_extend_in_batch, *two-batch-overlap values].
        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
            global_info.flatten(),
            local_info,
            group=tp_cpu_group,
        )
        # Read back the first attention-TP rank of every DP group.
        global_num_tokens = global_info[:, 0, 0].tolist()
        # CUDA graphs are used only if every DP group can use them.
        can_cuda_graph = min(global_info[:, 0, 1].tolist())
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        # Some other DP rank has work: give this rank an idle batch to run.
        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
            # With a DP LM head (and moe_dense_tp_size == 1) only the local
            # counts are kept; otherwise the gathered per-DP-group counts are used.
            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
                local_batch.global_num_tokens = global_num_tokens
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
                local_batch.can_run_dp_cuda_graph = can_cuda_graph

        return local_batch, any(is_extend_in_batch)
1812
1813
1814
1815
1816

    def get_idle_batch(self):
        """Create an empty ``ScheduleBatch`` prepared for an idle forward pass.

        Passed to ``prepare_dp_attn_batch_raw``, which calls it when this rank
        has no batch but another DP rank has tokens to process.
        """
        empty_batch = ScheduleBatch.init_new(
            [],  # no requests
            self.req_to_token_pool,
            self.token_to_kv_pool_allocator,
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
            self.spec_algorithm,
            self.server_args.enable_custom_logit_processor,
        )
        empty_batch.prepare_for_idle()
        return empty_batch

1827
1828
    def move_ready_grammar_requests(self):
        """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.

        The queue is scanned in order. Each pending grammar future is polled for
        ``poll_timeout`` seconds, and scanning stops at the first request whose
        grammar is not ready. That request counts as timed out once its
        accumulated poll count exceeds ``GRAMMAR_TIMEOUT / poll_timeout``.

        With more than one (attention-)TP rank, the ready/timeout counts are
        MAX-all-reduced so every rank moves the same prefix of the queue:

        * Ranks that saw fewer ready grammars block until those grammars resolve.
        * The timed-out request, if any, is the one right after the agreed
          ready prefix. It is cancelled, aborted, and its grammar key is cached
          as ``INVALID_GRAMMAR_OBJ``.
        """
        poll_timeout = 0.03
        num_ready_reqs = 0
        num_timeout_reqs = 0
        for req in self.grammar_queue:
            if req.finished():  # It is aborted by AbortReq
                num_ready_reqs += 1
                continue
            try:
                req.grammar = req.grammar.result(timeout=poll_timeout)
            except futures.TimeoutError:
                req.grammar_wait_ct += 1
                # NOTE(lianmin): this timeout is the waiting time of the above line. It is
                # not the waiting time from it enters the grammar queue.
                if req.grammar_wait_ct > GRAMMAR_TIMEOUT / poll_timeout:
                    num_timeout_reqs = 1
                break
            self._cache_resolved_grammar(req)
            num_ready_reqs += 1

        if self.server_args.enable_dp_attention:
            tp_size = self.attn_tp_size
            tp_group = self.attn_tp_cpu_group
        else:
            tp_size = self.tp_size
            tp_group = self.tp_cpu_group

        if tp_size > 1:
            # Sync across TP ranks to make sure they have the same number of ready requests
            tensor = torch.tensor([num_ready_reqs, num_timeout_reqs], dtype=torch.int32)
            torch.distributed.all_reduce(
                tensor, op=torch.distributed.ReduceOp.MAX, group=tp_group
            )
            num_ready_reqs_max, num_timeout_reqs_max = tensor.tolist()

            for i in range(num_ready_reqs, num_ready_reqs_max):
                req = self.grammar_queue[i]
                if req.finished():  # It is aborted by AbortReq
                    continue
                # Some rank already counted this request as ready; block until the
                # local grammar is resolved too.
                req.grammar = req.grammar.result()
                self._cache_resolved_grammar(req)
        else:
            num_ready_reqs_max = num_ready_reqs
            num_timeout_reqs_max = num_timeout_reqs

        # The timed-out request sits right after the ready prefix that all ranks
        # agreed on. Starting from the local count instead would make a lagging
        # rank cancel a grammar it just resolved, and abort a different request
        # than its peers. Clamp in case the agreed range runs past the queue end.
        num_moved = min(
            num_ready_reqs_max + num_timeout_reqs_max, len(self.grammar_queue)
        )
        for req in self.grammar_queue[num_ready_reqs_max:num_moved]:
            req.grammar.cancel()
            error_msg = f"Grammar preprocessing timed out for {req.grammar_key=}"
            req.set_finish_with_abort(error_msg)
            self.grammar_backend.set_cache(req.grammar_key, INVALID_GRAMMAR_OBJ)

        self._extend_requests_to_queue(self.grammar_queue[:num_moved])
        self.grammar_queue = self.grammar_queue[num_moved:]

    def _cache_resolved_grammar(self, req):
        """Cache a freshly resolved grammar and abort the request if it is invalid."""
        self.grammar_backend.set_cache(req.grammar_key, req.grammar.copy())
        if req.grammar is INVALID_GRAMMAR_OBJ:
            req.set_finish_with_abort(f"Invalid grammar request: {req.grammar_key=}")

1892
1893
1894
1895
1896
1897
1898
    def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
        """Finalize the next batch's sampling info and signal ``sampling_info_done``.

        When grammars are attached, the regex vocab mask is updated and the
        current stream is synchronized before signalling. Does nothing if the
        batch has no next-batch sampling info.
        """
        info = batch.next_batch_sampling_info
        if not info:
            return
        if info.grammars is not None:
            info.update_regex_vocab_mask()
            self.current_stream.synchronize()
        info.sampling_info_done.set()

1899
1900
1901
    def watchdog_thread(self):
        """A watch dog thread that will try to kill the server itself if one forward batch takes too long.

        Polls every ``watchdog_timeout / 2`` seconds. It fires if a batch is set
        (``cur_batch`` is not None) and ``forward_ct`` has not advanced for more
        than ``watchdog_timeout`` seconds. When it fires, it:

        * logs batch and memory-pool info (unless request logging is disabled);
        * calls ``pyspy_dump_schedulers()``;
        * waits 5 seconds;
        * sends SIGQUIT to the parent process.
        """
        self.watchdog_last_forward_ct = 0
        self.watchdog_last_time = time.perf_counter()

        while True:
            current = time.perf_counter()
            if self.cur_batch is not None:
                if self.watchdog_last_forward_ct == self.forward_ct:
                    if current > self.watchdog_last_time + self.watchdog_timeout:
                        break
                else:
                    self.watchdog_last_forward_ct = self.forward_ct
                    self.watchdog_last_time = current
            # True division: floor division yields 0 for timeouts below 2 seconds,
            # which would turn this loop into a busy spin.
            time.sleep(self.watchdog_timeout / 2)

        if not disable_request_logging():
            # Print batch size and memory pool info to check whether there are de-sync issues.
            logger.error(
                f"{self.cur_batch.batch_size()=}, "
                f"{self.cur_batch.reqs=}, "
                f"{self.token_to_kv_pool_allocator.available_size()=}, "
                f"{self.tree_cache.evictable_size()=}, "
            )

        pyspy_dump_schedulers()
        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
        print(file=sys.stderr, flush=True)
        print(file=sys.stdout, flush=True)

        # Wait for some time so that the parent process can print the error.
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

1933
1934
1935
    def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
        """Handle a flush-cache request and report whether ``flush_cache`` succeeded."""
        return FlushCacheReqOutput(success=self.flush_cache())
1936

1937
    def flush_cache(self):
        """Flush the memory pool and cache.

        The flush only happens when the scheduler is idle: the waiting queue is
        empty, the running batch is empty and, with pipeline parallelism, every
        micro-batch in ``self.running_mbs`` is empty. In that case the prefix
        tree, grammar backend, request/token pools (including the draft worker's
        pools when speculative decoding is enabled) and decoding/speculative
        statistics are reset. Otherwise nothing is touched and a warning is logged.

        Returns:
            True if the caches were flushed, False if pending requests prevented it.
        """
        is_idle = (
            len(self.waiting_queue) == 0
            and self.running_batch.is_empty()
            and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
        )
        if not is_idle:
            # Module-level `logger` (as used everywhere else in this file) rather
            # than the root `logging` module.
            logger.warning(
                "Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {len(self.running_batch.reqs)}"
            )
            return False

        self.cur_batch = None
        self.last_batch = None
        self.tree_cache.reset()
        if self.grammar_backend:
            self.grammar_backend.reset()
        self.req_to_token_pool.clear()
        self.token_to_kv_pool_allocator.clear()

        if not self.spec_algorithm.is_none():
            self.draft_worker.model_runner.req_to_token_pool.clear()
            self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

        self.num_generated_tokens = 0
        self.forward_ct_decode = 0
        self.spec_num_total_accepted_tokens = 0
        self.spec_num_total_forward_ct = 0
        self.cum_spec_accept_length = 0
        self.cum_spec_accept_count = 0
        torch.cuda.empty_cache()
        logger.info("Cache flushed successfully!")
        return True

Liangsheng Yin's avatar
Liangsheng Yin committed
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
    def get_load(self):
        """Return an approximate load figure for this scheduler, in tokens.

        The value is the sum of:
          * KV-cache tokens that are neither free nor evictable
            (``max_total_num_tokens - available - evictable``),
          * prompt lengths of the requests in ``self.waiting_queue``,
          * in PREFILL disaggregation mode, prompt lengths in the bootstrap queue;
            in DECODE mode, prompt lengths in the preallocation queue.
        """
        # TODO(lsyin): use dynamically maintained num_waiting_tokens
        occupied_tokens = (
            self.max_total_num_tokens
            - self.token_to_kv_pool_allocator.available_size()
            - self.tree_cache.evictable_size()
        )
        waiting_tokens = sum(len(r.origin_input_ids) for r in self.waiting_queue)

        mode = self.disaggregation_mode
        if mode == DisaggregationMode.PREFILL:
            bootstrap = self.disagg_prefill_bootstrap_queue.queue
            queued_tokens = sum(len(r.origin_input_ids) for r in bootstrap)
        elif mode == DisaggregationMode.DECODE:
            prealloc = self.disagg_decode_prealloc_queue.queue
            # Entries of the prealloc queue wrap the request in `.req`.
            queued_tokens = sum(len(entry.req.origin_input_ids) for entry in prealloc)
        else:
            queued_tokens = 0

        return occupied_tokens + waiting_tokens + queued_tokens

1995
1996
1997
1998
1999
2000
2001
2002
2003
    def get_internal_state(self, recv_req: GetInternalStateReq):
        """Return a copy of the global server args enriched with runtime stats.

        Added keys: ``last_gen_throughput``; ``avg_spec_accept_length`` when
        speculative decoding is on and has recorded acceptances;
        ``step_time_dict`` when ``RECORD_STEP_TIME`` is set; and ``load`` from
        ``self.get_load()``.
        """
        state = dict(global_server_args_dict)
        state["last_gen_throughput"] = self.last_gen_throughput

        has_spec_stats = (
            not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0
        )
        if has_spec_stats:
            state["avg_spec_accept_length"] = (
                self.cum_spec_accept_length / self.cum_spec_accept_count
            )
        if RECORD_STEP_TIME:
            state["step_time_dict"] = self.step_time_dict

        state["load"] = self.get_load()
        return GetInternalStateReqOutput(internal_state=state)
2008
2009
2010
2011
2012

    def set_internal_state(self, recv_req: SetInternalStateReq):
        """Update a whitelisted subset of the global server args at runtime.

        Every key in ``recv_req.server_args`` is validated before anything is
        written. The first unsupported key, or an out-of-range
        ``max_micro_batch_size`` (valid range ``[1, max_running_requests // pp_size]``),
        rejects the whole request. On success the speculative-decoding acceptance
        statistics are logged (when available) and reset, and the new values are
        stored in ``global_server_args_dict``.

        Returns:
            ``SetInternalStateReqOutput`` whose ``updated`` flag reports whether the
            new values were actually applied, plus the current global server args.
        """
        server_args_dict = recv_req.server_args
        args_allow_update = {
            "max_micro_batch_size",
            "speculative_accept_threshold_single",
            "speculative_accept_threshold_acc",
        }
        if_success = True
        for k, v in server_args_dict.items():
            if k not in args_allow_update:
                logger.warning(f"Updating {k} is not supported.")
                if_success = False
                break
            elif k == "max_micro_batch_size" and (
                v > self.max_running_requests // self.pp_size or v < 1
            ):
                logger.warning(
                    f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
                )
                if_success = False
                break

        if if_success:
            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
                avg_spec_accept_length = (
                    self.cum_spec_accept_length / self.cum_spec_accept_count
                )
                logger.info(f"{avg_spec_accept_length=}")
            self.cum_spec_accept_length = self.cum_spec_accept_count = 0
            for k, v in server_args_dict.items():
                global_server_args_dict[k] = v
            logger.info(f"Global server args updated! {global_server_args_dict=}")

        return SetInternalStateReqOutput(
            # Report the real outcome; this was hard-coded to True, which made
            # rejected updates look applied to the caller.
            updated=if_success,
            server_args=global_server_args_dict,
        )

2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
    def handle_rpc_request(self, recv_req: RpcReqInput):
        """Call ``getattr(self, recv_req.method)(recv_req.parameters)`` as an RPC.

        Any exception from the lookup or the call is caught, logged and returned
        as the output message. ``barrier()`` runs unconditionally afterwards,
        presumably so every rank reaches it even when its own call failed.

        Returns:
            ``RpcReqOutput(success, message)`` where ``message`` is ``""`` on
            success and ``str(exception)`` on failure.
        """
        # Handle RPC requests
        logger.info(
            f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
        )

        # Named `error` (was `exec`) to avoid shadowing the builtin.
        error = None
        try:
            func = getattr(self, recv_req.method)
            func(recv_req.parameters)
        except Exception as e:
            error = e
            logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

        barrier()
        # NOTE(review): any attribute of the scheduler is callable by name here;
        # confirm RPC requests only ever come from trusted internal callers.
        return RpcReqOutput(error is None, "" if error is None else str(error))

    def save_remote_model(self, params):
        """Save the model to the remote location given by ``params["url"]``.

        Delegates to the model runner of the underlying TP worker.
        """
        target_url = params["url"]
        runner = self.tp_worker.worker.model_runner
        runner.save_remote_model(target_url)

    def save_sharded_model(self, params):
        """Save the model as shards via the TP worker's model runner.

        ``params`` must provide ``"path"``, ``"pattern"`` and ``"max_size"``,
        which are forwarded as keyword arguments.
        """
        runner = self.tp_worker.worker.model_runner
        runner.save_sharded_model(
            path=params["path"], pattern=params["pattern"], max_size=params["max_size"]
        )

2082
2083
    def abort_request(self, recv_req: AbortReq):
        """Abort every request whose rid starts with ``recv_req.rid``.

        Matching is by prefix, so one AbortReq may cover several requests
        (presumably e.g. the sub-requests of a parallel-sampling request; confirm).
        The abort mechanism depends on where each request currently lives:

        1. Waiting queue: removed from the queue, and an ``AbortReq`` is sent to
           the TokenizerManager.
        2. Grammar queue: its grammar is cancelled and it is marked finished with
           an abort.
        3. Running batch / current batch: flagged ``to_abort`` (unless already
           finished).
        """
        target_prefix = recv_req.rid

        # Abort method 1: directly pop from the queue
        # This only works for requests that have not started anything.
        # We still need to send something back to TokenizerManager to clean up the state.
        matched_positions = [
            pos
            for pos, queued in enumerate(self.waiting_queue)
            if queued.rid.startswith(target_prefix)
        ]
        # Pop from the highest index down so the remaining positions stay valid.
        for pos in matched_positions[::-1]:
            req = self.waiting_queue.pop(pos)
            self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
            logger.debug(f"Abort queued request. {req.rid=}")

        # Abort method 2: call `set_finish_with_abort` on grammar-queue requests.
        # The request will still run one prefill forward pass.
        # In this case, we change the input_ids to be only one token to make this prefill cheap.
        for req in self.grammar_queue:
            if not req.rid.startswith(target_prefix):
                continue
            logger.debug(f"Abort grammar queue request. {req.rid=}")
            req.grammar.cancel()
            req.set_finish_with_abort("Aborted by AbortReq.")

        # Abort method 3: set `to_abort=True` on running requests.
        # The request will still run one decode forward pass.
        # Then we reuse all existing code to clean up the KV cache allocation.
        candidates = self.running_batch.reqs
        if self.cur_batch is not None and self.cur_batch is not self.running_batch:
            candidates = candidates + self.cur_batch.reqs
        for req in candidates:
            if req.rid.startswith(target_prefix) and not req.finished():
                logger.debug(f"Abort running request. {req.rid=}")
                req.to_abort = True
2121

2122
2123
2124
    def _pause_engine(self) -> Tuple[List[Req], int]:
        """Pause the engine; not supported by this scheduler.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

Chayenne's avatar
Chayenne committed
2125
2126
2127
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        """In-place update of the weights from disk.

        On success the caches are flushed (presumably because cached KV entries
        were computed with the old weights). A failed flush is fatal. On failure
        of the weight update itself, the worker's message is logged.

        Returns:
            ``UpdateWeightFromDiskReqOutput(success, message, 0)``.

        Raises:
            AssertionError: if the cache flush after a successful update fails.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            # Explicit raise instead of `assert`: asserts are stripped under
            # `python -O`, which would silently keep stale cache entries.
            if not flush_cache_success:
                raise AssertionError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
2134

2135
2136
2137
    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
        """Initialize the online model parameter update group.

        Delegates to the TP worker and wraps its ``(success, message)`` result.
        """
        ok, msg = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(ok, msg)
2139
2140

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        On success the caches are flushed, and a failed flush is fatal. On
        failure of the update, the worker's message is logged.

        Returns:
            ``UpdateWeightsFromDistributedReqOutput(success, message)`` (the old
            ``Tuple[bool, str]`` annotation did not match what is returned).

        Raises:
            AssertionError: if the cache flush after a successful update fails.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            flush_cache_success = self.flush_cache()
            # Explicit raise so the check survives `python -O`.
            if not flush_cache_success:
                raise AssertionError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
2152

2153
2154
2155
2156
2157
    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameter from tensors.

        After a successful update the caches are flushed only when
        ``recv_req.flush_cache`` is set; a failed flush is fatal. On failure of
        the update, the worker's message is logged.

        Raises:
            AssertionError: if a requested cache flush fails.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
                flush_cache_success = self.flush_cache()
                # Explicit raise so the check survives `python -O`.
                if not flush_cache_success:
                    raise AssertionError("Cache flush failed after updating weights")
        else:
            logger.error(message)
        return UpdateWeightsFromTensorReqOutput(success, message)
2164

2165
2166
    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        """Fetch the requested parameter from the TP worker and wrap it in the response."""
        return GetWeightsByNameReqOutput(self.tp_worker.get_weights_by_name(recv_req))
2168

2169
    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        """Release memory through the memory saver adapter.

        The model's buffers are snapshotted into ``self.stashed_model_static_state``
        first so ``resume_memory_occupation`` can restore them. Then the adapter
        is paused and the caches are flushed.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="release_memory_occupation")

        model = self.tp_worker.worker.model_runner.model
        self.stashed_model_static_state = _export_static_state(model)

        adapter.pause()
        # NOTE(review): flush_cache() returns False (and flushes nothing) while
        # requests are pending; that result is ignored here.
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()
2179

2180
    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
        """Resume the memory saver adapter and restore the stashed model buffers.

        Expects ``self.stashed_model_static_state`` to have been set by
        ``release_memory_occupation``; the stash is deleted after restoring.
        """
        adapter = self.memory_saver_adapter
        adapter.check_validity(caller_name="resume_memory_occupation")
        adapter.resume()

        model = self.tp_worker.worker.model_runner.model
        _import_static_state(model, self.stashed_model_static_state)
        del self.stashed_model_static_state

        return ResumeMemoryOccupationReqOutput()

2189
2190
2191
2192
2193
2194
2195
    def slow_down(self, recv_req: SlowDownReqInput):
        """Store the requested ``forward_sleep_time`` on the scheduler.

        A non-positive value is normalized to ``None``. How the stored value is
        consumed is not visible in this method.
        """
        requested = recv_req.forward_sleep_time
        non_positive = requested is not None and requested <= 0
        self.forward_sleep_time = None if non_positive else requested
        return SlowDownReqOutput()

2196
    def profile(self, recv_req: ProfileReq):
        """Handle a start/stop profiling request.

        START_PROFILE stores the configuration via ``init_profile``. In per-stage
        mode the profiler is started later, by ``_profile_batch_predicate``, so only
        the init result is returned; otherwise profiling starts immediately.
        Any other request type stops the current profile.
        """
        if recv_req.type != ProfileReqType.START_PROFILE:
            return self.stop_profile()

        init_result = self.init_profile(
            recv_req.output_dir,
            recv_req.num_steps,
            recv_req.activities,
            recv_req.with_stack,
            recv_req.record_shapes,
            recv_req.profile_by_stage,
            recv_req.profile_id,
        )
        # Return the init result as-is in per-stage mode, and also when init
        # rejected the request (a profile is already running): starting anyway
        # would replace the active profiler object and lose its trace.
        if recv_req.profile_by_stage or not init_result.success:
            return init_result
        # start_profile's argument is a ForwardMode stage used in the log message;
        # passing True only produced a bogus "for True" there.
        return self.start_profile()

2222
    def init_profile(
        self,
        output_dir: Optional[str],
        num_steps: Optional[int],
        activities: Optional[List[str]],
        with_stack: Optional[bool],
        record_shapes: Optional[bool],
        profile_by_stage: bool,
        profile_id: str,
    ) -> ProfileReqOutput:
        """Store the profiler configuration without starting any profiler.

        Args:
            output_dir: Trace directory; ``None`` means ``$SGLANG_TORCH_PROFILER_DIR``
                or ``/tmp``.
            num_steps: Number of steps to profile. A falsy value clears
                ``profiler_target_forward_ct`` (no automatic stop target).
            activities: Activity names; ``None`` means ``["CPU", "GPU"]``.
            with_stack: Stored for ``start_profile`` (torch profiler option).
            record_shapes: Stored for ``start_profile`` (torch profiler option).
            profile_by_stage: Count prefill and decode batches separately, each
                with a target of ``num_steps``.
            profile_id: Identifier used in the trace file names.

        Returns:
            A failed ``ProfileReqOutput`` if a profile is already in progress,
            otherwise a successful one.
        """
        if self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is already in progress. Call /stop_profile first.",
            )

        self.profile_by_stage = profile_by_stage
        self.torch_profiler_output_dir = (
            os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")
            if output_dir is None
            else output_dir
        )
        self.torch_profiler_with_stack = with_stack
        self.torch_profiler_record_shapes = record_shapes
        self.profiler_activities = ["CPU", "GPU"] if activities is None else activities
        self.profile_id = profile_id

        if not num_steps:
            self.profiler_target_forward_ct = None
            return ProfileReqOutput(success=True, message="Succeeded")

        self.profile_steps = num_steps
        if profile_by_stage:
            # Per-stage mode: separate counters/targets for prefill and decode.
            self.profiler_target_prefill_ct = num_steps
            self.profiler_target_decode_ct = num_steps
            self.profiler_prefill_ct = 0
            self.profiler_decode_ct = 0
        else:
            # The caller will be notified when reaching profiler_target_forward_ct
            self.profiler_target_forward_ct = self.forward_ct + num_steps
        return ProfileReqOutput(success=True, message="Succeeded")

    def start_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Start the profilers selected by ``self.profiler_activities``.

        The configuration (output dir, activities, stack/shape recording, profile
        id) comes from the attributes set by ``init_profile``. Recognized
        activity names:

        * ``"RPD"``: rpdTracerControl tracing. TP rank 0 recreates ``trace.rpd``
          in the current working directory, then all ranks synchronize on
          ``self.tp_cpu_group``. When present, the torch profiler is not started
          even if ``"CPU"``/``"GPU"`` are also listed.
        * ``"CPU"`` / ``"GPU"``: ``torch.profiler`` with CPU / CUDA activities.
        * ``"MEM"``: CUDA memory-history recording.
        * ``"CUDA_PROFILER"``: ``cudaProfilerStart()`` (presumably for an external
          profiler that honors the CUDA profiler API).

        Unknown activity names are ignored.

        Args:
            stage: Forward stage being profiled in per-stage mode; only used in
                the log message.

        Returns:
            A successful ``ProfileReqOutput``.
        """
        stage_str = f" for {str(stage)}" if stage else ""
        logger.info(
            f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
        )

        activities = self.profiler_activities
        with_stack = self.torch_profiler_with_stack
        record_shapes = self.torch_profiler_record_shapes

        activity_map = {
            "CPU": torch.profiler.ProfilerActivity.CPU,
            "GPU": torch.profiler.ProfilerActivity.CUDA,
        }
        torchprof_activities = [
            activity_map[a] for a in activities if a in activity_map
        ]

        if "RPD" in activities:
            from rpdTracerControl import rpdTracerControl

            rpdTracerControl.skipCreate()

            self.rpd_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                "rpd-" + str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
            )

            if self.tp_rank == 0:
                import sqlite3

                from rocpd.schema import RocpdSchema

                if os.path.exists("trace.rpd"):
                    os.unlink("trace.rpd")
                schema = RocpdSchema()
                connection = sqlite3.connect("trace.rpd")
                schema.writeSchema(connection)
                connection.commit()
                del connection
            torch.distributed.barrier(self.tp_cpu_group)

            self.rpd_profiler = rpdTracerControl()
            self.rpd_profiler.setPythonTrace(True)
            self.rpd_profiler.start()
            self.rpd_profiler.rangePush("", "rpd profile range", "")
            self.profile_in_progress = True
        elif torchprof_activities:
            self.torch_profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            self.torch_profiler.start()
            self.profile_in_progress = True

        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
            self.profile_in_progress = True

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStart()
            # Mark profiling as active so stop_profile() accepts the stop request
            # and calls cudaProfilerStop(); without this a CUDA_PROFILER-only
            # session could never be stopped.
            self.profile_in_progress = True

        return ProfileReqOutput(success=True, message="Succeeded")
2332

2333
2334
2335
2336
    def stop_profile(
        self, stage: Optional[ForwardMode] = None
    ) -> ProfileReqOutput | None:
        """Stop the active profilers and write their outputs.

        Outputs go to ``self.torch_profiler_output_dir`` (created if missing):

        * torch profiler: ``<profile_id>-TP-<tp_rank>[-<stage>].trace.json.gz``
        * RPD: ``trace.rpd`` converted to ``self.rpd_profile_path`` on TP rank 0
        * MEM: ``<timestamp>-TP-<tp_rank>-memory[-<stage>].pickle``

        Args:
            stage: Forward stage whose profile is being stopped (per-stage mode);
                appended to the output file names.

        Returns:
            A failed ``ProfileReqOutput`` if no profile is in progress, otherwise
            a successful one.
        """
        if not self.profile_in_progress:
            return ProfileReqOutput(
                success=False,
                message="Profiling is not in progress. Call /start_profile first.",
            )

        # exist_ok makes a separate exists() check unnecessary.
        Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

        stage_suffix = f"-{str(stage)}" if stage else ""
        logger.info("Stop profiling" + stage_suffix + "...")
        # profiler_activities may be None (the MEM check always guarded against
        # that); treat it as "no activities" for every check below.
        activities = self.profiler_activities or []

        if self.torch_profiler is not None:
            self.torch_profiler.stop()
            self.torch_profiler.export_chrome_trace(
                os.path.join(
                    self.torch_profiler_output_dir,
                    self.profile_id
                    + f"-TP-{self.tp_rank}"
                    + stage_suffix
                    + ".trace.json.gz",
                )
            )
            torch.distributed.barrier(self.tp_cpu_group)

        if self.rpd_profiler is not None:
            self.rpd_profiler.rangePop()
            self.rpd_profiler.stop()
            self.rpd_profiler.flush()

            torch.distributed.barrier(self.tp_cpu_group)
            if self.tp_rank == 0:
                from sglang.srt.utils import rpd_to_chrome_trace

                rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
            self.rpd_profiler = None
            # Reset the attribute start_profile() actually sets (it was misspelled
            # `rpd_profiler_path`, leaving the stale path in place).
            self.rpd_profile_path = None

        if "MEM" in activities:
            memory_profile_path = os.path.join(
                self.torch_profiler_output_dir,
                str(time.time())
                + f"-TP-{self.tp_rank}-memory"
                + stage_suffix
                + ".pickle",
            )
            torch.cuda.memory._dump_snapshot(memory_profile_path)
            torch.cuda.memory._record_memory_history(enabled=None)

        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()

        logger.info(
            "Profiling done. Traces are saved to: %s",
            self.torch_profiler_output_dir,
        )
        self.torch_profiler = None
        self.profile_in_progress = False

        return ProfileReqOutput(success=True, message="Succeeded.")

    def _profile_batch_predicate(self, batch):
        """Start or stop profiling around ``batch`` according to the profiling mode.

        Per-stage mode (``self.profile_by_stage``): the first prefill batch starts a
        prefill profile, which stops after ``profiler_target_prefill_ct`` prefill
        batches. The first decode batch force-stops a still-running prefill profile
        and starts a decode profile, which stops after ``profiler_target_decode_ct``
        decode batches.

        Otherwise, profiling stops once ``self.forward_ct`` reaches
        ``self.profiler_target_forward_ct``.

        Raises:
            RuntimeError: in per-stage mode, for a batch that is neither prefill
                nor decode.
        """
        if self.profile_by_stage:
            if batch.forward_mode.is_prefill():
                if self.profiler_prefill_ct == 0:
                    self.start_profile(batch.forward_mode)
                self.profiler_prefill_ct += 1
                if self.profiler_prefill_ct > self.profiler_target_prefill_ct:
                    # Once a decode batch has been seen, the prefill profile was
                    # already flushed below and any profile still in progress is
                    # the decode one. A late prefill batch must not stop it: that
                    # would export the decode trace under the EXTEND file name and
                    # overwrite the prefill trace.
                    if self.profile_in_progress and self.profiler_decode_ct == 0:
                        self.stop_profile(stage=ForwardMode.EXTEND)
            elif batch.forward_mode.is_decode():
                if self.profiler_decode_ct == 0:
                    if self.profile_in_progress:
                        # force trace flush
                        self.stop_profile(ForwardMode.EXTEND)
                    self.start_profile(batch.forward_mode)
                self.profiler_decode_ct += 1
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
            else:
                raise RuntimeError("unsupported profile stage")
        else:
            # Check profiler
            if (
                self.profiler_target_forward_ct
                and self.profiler_target_forward_ct <= self.forward_ct
            ):
                self.stop_profile()
2424

2425
2426
    def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
        """Start, stop, or dump the global expert-distribution recorder.

        Raises:
            ValueError: if ``recv_req`` is not one of the three known commands.
        """
        if recv_req == ExpertDistributionReq.START_RECORD:
            action = "start_record"
        elif recv_req == ExpertDistributionReq.STOP_RECORD:
            action = "stop_record"
        elif recv_req == ExpertDistributionReq.DUMP_RECORD:
            action = "dump_record"
        else:
            raise ValueError("Unrecognized ExpertDistributionReq value")
        getattr(get_global_expert_distribution_recorder(), action)()
        return ExpertDistributionReqOutput()
2435

2436
    def open_session(self, recv_req: OpenSessionReqInput):
        """Create a new ``Session`` keyed by ``recv_req.session_id``.

        Returns ``OpenSessionReqOutput(session_id, success)``. It fails, with a
        warning, when the id is already in use or is None.
        """
        sid = recv_req.session_id
        if sid in self.sessions:
            logger.warning(f"session id {sid} already exist, cannot open.")
            return OpenSessionReqOutput(sid, False)
        if sid is None:
            logger.warning("session id is None, cannot open.")
            return OpenSessionReqOutput(sid, False)

        self.sessions[sid] = Session(recv_req.capacity_of_str_len, sid)
        return OpenSessionReqOutput(sid, True)
2450
2451
2452
2453
2454
2455
2456
2457
2458

    def close_session(self, recv_req: CloseSessionReqInput):
        """Remove the session ``recv_req.session_id``; warn if it does not exist."""
        sid = recv_req.session_id
        if sid in self.sessions:
            del self.sessions[sid]
        else:
            logger.warning(f"session id {sid} does not exist, cannot delete.")

2459
2460
    def get_print_prefix(self):
        """Return a rank tag such as ``" DP0 TP1 PP2"``.

        Each part is included only when it applies: DP when ``attn_dp_rank`` is
        set, TP when ``server_args.tp_size > 1``, and PP when ``pp_size > 1``.
        """
        parts = []
        if self.attn_dp_rank is not None:
            parts.append(f" DP{self.attn_dp_rank}")
        if self.server_args.tp_size > 1:
            parts.append(f" TP{self.tp_rank}")
        if self.pp_size > 1:
            parts.append(f" PP{self.pp_rank}")
        return "".join(parts)

2469
2470
2471
2472
2473
2474
2475
    def _publish_kv_events(self):
        if self.enable_kv_cache_events:
            events = self.tree_cache.take_events()
            if events:
                batch = KVEventBatch(ts=time.time(), events=events)
                self.kv_event_publisher.publish(batch)

2476

2477
2478
2479
2480
def is_health_check_generate_req(recv_req):
    """Return True if ``recv_req`` is a health-check generate request.

    Health checks are recognized by a string ``rid`` starting with
    ``"HEALTH_CHECK"``. Objects without an ``rid`` attribute, or whose ``rid`` is
    not a string (e.g. ``None``), are not health checks. Previously a ``None``
    rid raised AttributeError.
    """
    rid = getattr(recv_req, "rid", None)
    return isinstance(rid, str) and rid.startswith("HEALTH_CHECK")


2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
def _export_static_state(model):
    return dict(
        buffers=[
            (name, buffer.detach().clone()) for name, buffer in model.named_buffers()
        ]
    )


def _import_static_state(model, static_params):
    self_named_buffers = dict(model.named_buffers())
    for name, tensor in static_params["buffers"]:
        self_named_buffers[name][...] = tensor


2495
2496
2497
2498
2499
def run_scheduler_process(
    server_args: ServerArgs,
    port_args: PortArgs,
    gpu_id: int,
    tp_rank: int,
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
):
    """Entry point of a scheduler subprocess.

    Configures the process (title, logging, optional CPU affinity, VLM embedding
    cache), builds a ``Scheduler``, and reports readiness plus
    ``max_total_num_tokens`` / ``max_req_input_len`` through ``pipe_writer``.
    It then runs the event loop matching the disaggregation mode, pipeline
    parallelism and overlap settings. On any exception the traceback is logged
    and SIGQUIT is sent to the parent process.
    """
    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var.
    # Resolved before building the prefix so the process title and log prefix
    # include a router-assigned DP rank.
    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
        dp_rank = int(os.environ["SGLANG_DP_RANK"])

    # Generate the prefix
    prefix = ""
    if dp_rank is not None:
        prefix += f" DP{dp_rank}"
    if server_args.tp_size > 1:
        prefix += f" TP{tp_rank}"
    if server_args.pp_size > 1:
        prefix += f" PP{pp_rank}"

    # Config the process
    kill_itself_when_parent_died()
    setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
    faulthandler.enable()
    parent_process = psutil.Process().parent()

    # Configure the logger
    configure_logger(server_args, prefix=prefix)
    suppress_other_loggers()

    # Set cpu affinity to this gpu process
    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

    embedding_cache_size = 100
    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
    init_embedding_cache(embedding_cache_size * 1024 * 1024)

    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
        pipe_writer.send(
            {
                "status": "ready",
                "max_total_num_tokens": scheduler.max_total_num_tokens,
                "max_req_input_len": scheduler.max_req_input_len,
            }
        )
        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

        if disaggregation_mode == DisaggregationMode.NULL:
            if server_args.pp_size > 1:
                scheduler.event_loop_pp()
            elif scheduler.enable_overlap:
                scheduler.event_loop_overlap()
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_prefill()
            else:
                scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
            if scheduler.enable_overlap:
                scheduler.event_loop_overlap_disagg_decode()
            else:
                scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"Scheduler hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)