gpu_model_runner.py 220 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import gc
5
import itertools
6
import time
7
from collections import defaultdict
8
from collections.abc import Iterator, Sequence
9
from contextlib import contextmanager
10
from copy import copy, deepcopy
11
from functools import reduce
12
from itertools import product
13
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
14
15
16
17
18

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
19
from tqdm import tqdm
20

21
import vllm.envs as envs
22
from vllm.attention import Attention, AttentionType
23
24
25
26
27
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionMetadata,
    MultipleOf,
)
28
from vllm.compilation.counter import compilation_counter
29
30
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
31
from vllm.config import (
32
    CompilationMode,
33
34
35
36
37
    CUDAGraphMode,
    VllmConfig,
    get_layers_from_vllm_config,
    update_config,
)
38
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
39
from vllm.distributed.eplb.eplb_state import EplbState
40
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
41
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
42
from vllm.distributed.parallel_state import (
43
    get_dcp_group,
44
45
46
47
48
49
    get_pp_group,
    get_tp_group,
    graph_capture,
    is_global_first_rank,
    prepare_communication_buffer_for_model,
)
50
from vllm.forward_context import BatchDescriptor, set_forward_context
51
from vllm.logger import init_logger
52
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
53
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
54
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
55
from vllm.model_executor.models.interfaces import (
56
    SupportsMRoPE,
57
58
59
60
61
62
63
    SupportsMultiModal,
    is_mixture_of_experts,
    supports_eagle3,
    supports_mrope,
    supports_multimodal_pruning,
    supports_transcription,
)
64
from vllm.model_executor.models.interfaces_base import (
65
66
67
68
    VllmModelForPooling,
    is_pooling_model,
    is_text_generation_model,
)
69
from vllm.multimodal import MULTIMODAL_REGISTRY
70
71
72
73
74
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
    PlaceholderRange,
)
75
from vllm.multimodal.utils import group_mm_kwargs_by_modality
76
from vllm.pooling_params import PoolingParams
77
from vllm.sampling_params import SamplingType
78
from vllm.sequence import IntermediateTensors
79
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
80
from vllm.utils import length_from_prompt_token_ids_or_embeds
81
from vllm.utils.jsontree import json_map_leaves
82
from vllm.utils.math_utils import cdiv, round_up
83
84
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
85
from vllm.utils.platform_utils import is_pin_memory_available
86
87
88
89
90
from vllm.utils.torch_utils import (
    get_dtype_size,
    kv_cache_dtype_str_to_dtype,
    supports_dynamo,
)
91
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
92
from vllm.v1.attention.backends.utils import (
93
94
95
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
96
    create_fast_prefill_custom_backend,
97
    get_dcp_local_seq_lens,
98
99
100
    reorder_batch_to_split_decodes_and_prefills,
    split_attn_metadata,
)
101
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    ChunkedLocalAttentionSpec,
    CrossAttentionSpec,
    EncoderOnlyAttentionSpec,
    FullAttentionSpec,
    KVCacheConfig,
    KVCacheGroupSpec,
    KVCacheSpec,
    MambaSpec,
    SlidingWindowSpec,
    UniformTypeKVCacheSpecs,
)
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    AsyncModelRunnerOutput,
    DraftTokenIds,
119
    ECConnectorOutput,
120
    KVConnectorOutput,
121
122
123
124
125
    LogprobsLists,
    LogprobsTensors,
    ModelRunnerOutput,
    PoolerOutput,
    SamplerOutput,
126
    make_empty_encoder_model_runner_output,
127
)
128
from vllm.v1.pool.metadata import PoolingMetadata
129
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
130
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
131
from vllm.v1.sample.metadata import SamplingMetadata
132
from vllm.v1.sample.rejection_sampler import RejectionSampler
133
from vllm.v1.sample.sampler import Sampler
134
from vllm.v1.spec_decode.eagle import EagleProposer
135
from vllm.v1.spec_decode.medusa import MedusaProposer
136
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
137
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
138
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
139
from vllm.v1.structured_output.utils import apply_grammar_bitmask
140
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
141
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
142
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
143
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
144
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
145
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
146
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
147
148
149
150
151
from vllm.v1.worker.ubatch_utils import (
    UBatchSlice,
    UBatchSlices,
    check_ubatch_thresholds,
)
152
from vllm.v1.worker.utils import is_residual_scattered_for_sp
153

154
155
156
157
158
159
160
161
162
from .utils import (
    AttentionGroup,
    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    gather_mm_placeholders,
    sanity_check_mm_encoder_outputs,
    scatter_mm_placeholders,
)
163

164
if TYPE_CHECKING:
165
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
166
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
167
168
169

logger = init_logger(__name__)

# Attention metadata keyed by a string (presumably the attention layer name).
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# A list of such dicts when ubatching is enabled (presumably one per
# micro-batch); otherwise a single dict.
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
173

174

175
176
177
178
179
180
# Wrapper for ModelRunnerOutput to support overlapped execution.
class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
    """A ModelRunnerOutput whose sampled tokens/logprobs reach the host lazily.

    The device-to-host copies are enqueued on a dedicated stream when the
    object is constructed; ``get_output`` waits for them and fills in the
    wrapped ``ModelRunnerOutput``.
    """

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        sampled_token_ids: torch.Tensor,
        logprobs_tensors: LogprobsTensors | None,
        invalid_req_indices: list[int],
        async_output_copy_stream: torch.cuda.Stream,
        vocab_size: int,
    ):
        """Enqueue the non-blocking device-to-host copies.

        Args:
            model_runner_output: Output object that ``get_output`` completes
                with the host-side sampled token ids and logprobs.
            sampled_token_ids: Device tensor of sampled token ids.
            logprobs_tensors: Optional device-side logprobs to copy as well.
            invalid_req_indices: Indices of requests whose sampled tokens are
                discarded (their token lists are cleared in ``get_output``).
            async_output_copy_stream: Side stream on which the copies run.
            vocab_size: Forwarded to ``RejectionSampler.parse_output``.
        """
        self._model_runner_output = model_runner_output
        self._invalid_req_indices = invalid_req_indices

        # Event on the copy stream so we can synchronize the non-blocking copy.
        self.async_copy_ready_event = torch.Event()

        # Keep references to the device tensors so they are not deallocated
        # before the copies below have finished.
        self._sampled_token_ids = sampled_token_ids
        self.vocab_size = vocab_size
        self._logprobs_tensors = logprobs_tensors

        # Initiate the copy on a separate stream, but do not synchronize it.
        default_stream = torch.cuda.current_stream()
        with torch.cuda.stream(async_output_copy_stream):
            # Order the copies after the work already queued on the default
            # stream, which is what produced the tensors being copied.
            async_output_copy_stream.wait_stream(default_stream)
            self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                "cpu", non_blocking=True
            )
            self._logprobs_tensors_cpu = (
                self._logprobs_tensors.to_cpu_nonblocking()
                if self._logprobs_tensors is not None
                else None
            )
            # Recorded on the copy stream (the current stream inside `with`).
            self.async_copy_ready_event.record()

    def get_output(self) -> ModelRunnerOutput:
        """Copy the device tensors to the host and return a ModelRunnerOutput.

        This function blocks until the copy is finished. It drops the
        references to the device tensors, so calling it a second time raises
        ``AttributeError``.
        """
        self.async_copy_ready_event.synchronize()

        # Release the device tensors once the copy has completed.
        del self._logprobs_tensors
        del self._sampled_token_ids

        # A last dimension of 1 can be converted directly; wider outputs are
        # parsed by the rejection sampler.
        max_gen_len = self.sampled_token_ids_cpu.shape[-1]
        if max_gen_len == 1:
            valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
        else:
            valid_sampled_token_ids = RejectionSampler.parse_output(
                self.sampled_token_ids_cpu,
                self.vocab_size,
            )
        # Discarded requests report no sampled tokens.
        for i in self._invalid_req_indices:
            valid_sampled_token_ids[i].clear()

        output = self._model_runner_output
        output.sampled_token_ids = valid_sampled_token_ids
        if self._logprobs_tensors_cpu is not None:
            # NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
            # for async sched + spec decode + logprobs compatibility.
            output.logprobs = self._logprobs_tensors_cpu.tolists()
        return output


242
243
244
245
246
247
248
249
250
251
252
class ExecuteModelState(NamedTuple):
    """Ephemeral cached state transferred between execute_model() and
    sample_tokens(), after execute_model() returns None.

    Being a NamedTuple, field order is part of the interface: any positional
    unpacking of this state depends on the declaration order below.
    """

    scheduler_output: "SchedulerOutput"
    logits: torch.Tensor
    spec_decode_metadata: SpecDecodeMetadata | None
    spec_decode_common_attn_metadata: CommonAttentionMetadata | None
    hidden_states: torch.Tensor
    sample_hidden_states: torch.Tensor
    aux_hidden_states: list[torch.Tensor] | None
    ec_connector_output: ECConnectorOutput | None
254
255


256
257
258
class GPUModelRunner(
    LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ECConnectorModelRunnerMixin
):
259
260
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, persistent buffers and helpers for the runner.

        Args:
            vllm_config: Top-level engine configuration; its sub-configs are
                cached on ``self`` for convenience.
            device: Device on which the GPU-side buffers are allocated.

        Raises:
            ValueError: If a speculative decoding method is configured that
                this runner does not know how to build a drafter for.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes

        set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype

        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
            cache_config.cache_dtype, self.model_config
        )

        self.is_pooling_model = model_config.runner_type == "pooling"
        self.enable_prompt_embeds = model_config.enable_prompt_embeds
        self.is_multimodal_raw_input_only_model = (
            model_config.is_multimodal_raw_input_only_model
        )
        # This will be overridden in load_model()
        self.is_multimodal_pruning_enabled = False
        self.max_model_len = model_config.max_model_len

        # Always set to false after the first forward pass
        self.calculate_kv_scales = self.cache_config.calculate_kv_scales
        self.dcp_world_size = self.parallel_config.decode_context_parallel_size
        self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Broadcast PP output for external_launcher (torchrun)
        # to make sure we are synced across pp ranks
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        # NOTE(review): `len(ranks) > 0` holds for any non-empty group, so this
        # is effectively "external_launcher is in use"; confirm whether `> 1`
        # (i.e. real pipeline parallelism) was intended.
        self.broadcast_pp_output = (
            self.parallel_config.distributed_executor_backend == "external_launcher"
            and len(get_pp_group().ranks) > 0
        )

        # Model-related.
        self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size
        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = model_config.uses_alibi

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            model_config
        )

        if self.model_config.is_encoder_decoder:
            # Maximum length of the encoder input, only for encoder-decoder
            # models.
            self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens
        else:
            self.max_encoder_len = 0

        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        # State of the expert parallelism load balancer.
        # Will be lazily initialized when the model is loaded.
        self.eplb_state: EplbState | None = None

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        # Initialize in initialize_kv_cache_tensors
        self.cross_layers_kv_cache: torch.Tensor | None = None
        self.cross_layers_attn_backend: type[AttentionBackend] | None = None
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig

        # mm_hash ->  encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}

        self.use_aux_hidden_state_outputs = False
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            self.drafter: (
                NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer
            )
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.method == "suffix":
                self.drafter = SuffixDecodingProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device, self)
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config, device=self.device
                )
            else:
                raise ValueError(
                    "Unknown speculative decoding method: "
                    f"{self.speculative_config.method}"
                )
            self.rejection_sampler = RejectionSampler(self.sampler)

        self.num_spec_tokens = 0
        if self.speculative_config:
            self.num_spec_tokens = self.speculative_config.num_speculative_tokens

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        self.comm_stream = torch.cuda.Stream()

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        logits_processors = model_config.logits_processors
        custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
            tuple(logits_processors) if logits_processors is not None else ()
        )
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            # We need to use the encoder length for encoder-decoder
            # because of KV cache for cross-attention.
            max_model_len=max(self.max_model_len, self.max_encoder_len),
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            kernel_block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=build_logitsprocs(
                self.vllm_config,
                self.device,
                self.pin_memory,
                self.is_pooling_model,
                custom_logitsprocs,
            ),
            # We currently don't know whether a particular custom logits processor
            # uses output token ids so we set this conservatively.
            logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
            is_pooling_model=self.is_pooling_model,
            cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
        )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        # Separate cuda stream for overlapping transfer of sampled token ids from
        # GPU to CPU when async scheduling is enabled.
        self.async_output_copy_stream: torch.cuda.Stream | None = None
        # cuda event to synchronize use of reused CPU tensors between steps
        # when async scheduling is enabled.
        self.prepare_inputs_event: torch.Event | None = None
        if self.use_async_scheduling:
            self.async_output_copy_stream = torch.cuda.Stream()
            self.prepare_inputs_event = torch.Event()

        # self.cudagraph_batch_sizes sorts in ascending order.
        if (
            self.compilation_config.cudagraph_capture_sizes
            and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        ):
            self.cudagraph_batch_sizes = sorted(
                self.compilation_config.cudagraph_capture_sizes
            )
        else:
            # Always define the attribute so readers see "no capture sizes"
            # instead of an AttributeError.
            self.cudagraph_batch_sizes = []

        # Cache the device properties.
        self._init_device_properties()

        # Persistent buffers for CUDA graphs.
        self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
        self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
        self.query_start_loc = self._make_buffer(
            self.max_num_reqs + 1, dtype=torch.int32
        )
        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
        if self.dcp_world_size > 1:
            self.dcp_local_seq_lens = self._make_buffer(
                self.max_num_reqs, dtype=torch.int32
            )
        # Because inputs_embeds may be bfloat16 and we don't need a numpy
        # version of this tensor, avoid a RuntimeError by not creating a
        # numpy buffer.
        self.inputs_embeds = self._make_buffer(
            self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
        )
        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
        self.discard_request_indices = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        )
        self.num_discarded_requests = 0

        self.num_decode_draft_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int32
        )
        self.num_accepted_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        )

        # Only relevant for multimodal models
        if self.supports_mm_inputs:
            self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool)

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = self._make_buffer(
                (3, self.max_num_tokens + 1), dtype=torch.int64
            )

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: IntermediateTensors | None = None

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(
            max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
            dtype=np.int64,
        )

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        self.kv_sharing_fast_prefill_logits_indices = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                self.max_num_tokens, dtype=torch.int32, device=self.device
            )

        self.uniform_decode_query_len = 1 + self.num_spec_tokens

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        self.mm_budget = (
            MultiModalBudget(
                self.model_config,
                self.scheduler_config,
                self.mm_registry,
            )
            if self.supports_mm_inputs
            else None
        )

        self.reorder_batch_threshold: int | None = None

        # Attention layers that are only in the KVCacheConfig of the runner
        # (e.g., KV sharing, encoder-only attention), but not in the
        # KVCacheConfig of the scheduler.
        self.runner_only_attn_layers: set[str] = set()

        # Cached outputs.
        self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
        self.transfer_event = torch.Event()
        self.sampled_token_ids_pinned_cpu = torch.empty(
            (self.max_num_reqs, 1),
            dtype=torch.int64,
            device="cpu",
            pin_memory=self.pin_memory,
        )

        # Pre-allocated tensor for copying valid sampled token counts to CPU,
        # with dedicated stream for overlapping and event for coordination.
        self.valid_sampled_token_count_event: torch.Event | None = None
        self.valid_sampled_token_count_copy_stream: torch.cuda.Stream | None = None
        if self.use_async_scheduling and self.num_spec_tokens:
            self.valid_sampled_token_count_event = torch.Event()
            self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
        self.valid_sampled_token_count_cpu = torch.empty(
            self.max_num_reqs,
            dtype=torch.int64,
            device="cpu",
            pin_memory=self.pin_memory,
        )

        # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
        self.kv_connector_output: KVConnectorOutput | None = None
582

583
584
585
586
    def reset_mm_cache(self) -> None:
        """Reset the multi-modal budget's cache; no-op when there is no budget."""
        budget = self.mm_budget
        if not budget:
            return
        budget.reset_cache()

587
588
589
590
591
592
593
594
595
596
    def _get_positions(self, num_tokens: Any):
        """Return a view of the persistent device position buffer.

        An ``int`` selects the first ``num_tokens`` positions; any other value
        is used verbatim as the index along the token axis.
        """
        index = slice(None, num_tokens) if isinstance(num_tokens, int) else num_tokens
        if self.uses_mrope:
            # M-RoPE positions have a leading axis of size 3; index the tokens.
            return self.mrope_positions.gpu[:, index]
        return self.positions.gpu[index]

597
    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
    ) -> CpuGpuBuffer:
        """Allocate a paired host/device buffer on this runner's device.

        Args:
            *size: Dimensions forwarded unchanged to ``CpuGpuBuffer``.
            dtype: Element dtype of the buffer.
            numpy: Passed as ``with_numpy``; disable it for dtypes NumPy cannot
                represent (e.g. bfloat16 ``inputs_embeds``).

        Returns:
            The new ``CpuGpuBuffer``; host memory is pinned when
            ``self.pin_memory`` is set.
        """
        return CpuGpuBuffer(
            *size,
            dtype=dtype,
            device=self.device,
            pin_memory=self.pin_memory,
            with_numpy=numpy,
        )
607

608
609
610
    def _init_model_kwargs(self, num_tokens: int):
        """Build extra keyword arguments for the model forward pass.

        Only pooling models get extra kwargs. If at least one request in the
        batch carries ``compressed_token_type_ids`` in its pooling
        ``extra_kwargs``, a ``token_type_ids`` tensor covering every request
        is produced. For such a request, token positions ``>=`` that value get
        type 1 and earlier positions type 0. Requests without the entry get
        all zeros.

        Args:
            num_tokens: Not used by this implementation.

        Returns:
            A (possibly empty) dict of model kwargs.
        """
        model_kwargs = dict[str, Any]()

        if not self.is_pooling_model:
            return model_kwargs

        num_reqs = self.input_batch.num_reqs
        pooling_params = self.input_batch.get_pooling_params()

        token_type_id_requests = dict[int, Any]()
        for i, param in enumerate(pooling_params):
            if (
                param.extra_kwargs is not None
                and (token_types := param.extra_kwargs.get("compressed_token_type_ids"))
                is not None
            ):
                token_type_id_requests[i] = token_types

        if not token_type_id_requests:
            return model_kwargs

        # Fetch all sequence lengths with a single device->host transfer.
        # Indexing the device tensor per request synchronized twice per
        # request, and the default `pos` was then a CUDA 0-dim tensor compared
        # against a CPU arange, which may fail with a device mismatch.
        seq_lens = self.seq_lens.gpu[:num_reqs].tolist()

        token_type_ids = [
            (torch.arange(seq_len) >= token_type_id_requests.get(i, seq_len)).int()
            for i, seq_len in enumerate(seq_lens)
        ]

        model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to(
            device=self.device
        )
        return model_kwargs

642
    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Nothing is reordered when there are no KV cache groups or when
        ``self.reorder_batch_threshold`` is None.

        Args:
            scheduler_output: The scheduler output.
        """
        # Attention free models have zero kv_cache_groups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

        if self.reorder_batch_threshold is not None:
            reorder_batch_to_split_decodes_and_prefills(
                self.input_batch,
                scheduler_output,
                decode_threshold=self.reorder_batch_threshold,
            )
666

667
668
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Initialize attributes from torch.cuda.get_device_properties.

        Sets ``self.device_properties`` for ``self.device`` and caches its
        ``multi_processor_count`` as ``self.num_sms``.
        """
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        """Block the host until queued work on the current CUDA device finishes."""
        cuda = torch.cuda
        cuda.synchronize()

677
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Side effects: besides ``self.requests`` and ``self.input_batch``, with
        async scheduling this may also shrink the spec-decode bookkeeping in
        ``scheduler_output`` (token counts and ``scheduled_spec_decode_tokens``).
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for mm_hash in scheduler_output.free_encoder_mm_hashes:
            self.encoder_cache.pop(mm_hash, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        reqs_to_add: list[CachedRequestState] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if (
                sampling_params
                and sampling_params.sampling_type == SamplingType.RANDOM_SEED
            ):
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if self.is_pooling_model:
                assert pooling_params is not None
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.get_model())
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt_embeds=new_req_data.prompt_embeds,
                mm_features=new_req_data.mm_features,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_mrope_positions(req_state)

            reqs_to_add.append(req_state)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs

        # Wait until valid_sampled_tokens_count is copied to cpu,
        # then use it to update actual num_computed_tokens of each request.
        valid_sampled_token_count = self._get_valid_sampled_token_count()

        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_id in req_data.resumed_req_ids
            num_output_tokens = req_data.num_output_tokens[i]
            req_index = self.input_batch.req_id_to_index.get(req_id)

            # prev_num_draft_len is used in async scheduling mode with
            # spec decode. It indicates whether num_computed_tokens of the
            # request needs to be corrected. For example:
            # first step: num_computed_tokens = 0, spec_tokens = [],
            # prev_num_draft_len = 0.
            # second step: num_computed_tokens = 100 (prompt length),
            # spec_tokens = [a,b], prev_num_draft_len = 0.
            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
            # prev_num_draft_len = 2.
            # num_computed_tokens in the first and second steps doesn't contain
            # the spec tokens length, but in the third step it does. We only
            # need to update num_computed_tokens when prev_num_draft_len > 0.
            if req_state.prev_num_draft_len:
                if req_index is None:
                    req_state.prev_num_draft_len = 0
                else:
                    assert self.input_batch.prev_req_id_to_index is not None
                    prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
                    num_accepted = valid_sampled_token_count[prev_req_index] - 1
                    num_rejected = req_state.prev_num_draft_len - num_accepted
                    num_computed_tokens -= num_rejected
                    # Placeholders for accepted tokens whose ids are not
                    # known on the CPU side at this point.
                    req_state.output_token_ids.extend([-1] * num_accepted)

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (
                    num_computed_tokens + len(new_token_ids) - req_state.num_tokens
                )
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:])
            elif num_output_tokens < len(req_state.output_token_ids):
                # Some output tokens were discarded due to a sync-KV-load
                # failure. Align the cached state.
                del req_state.output_token_ids[num_output_tokens:]
                if req_index is not None:
                    end_idx = (
                        self.input_batch.num_prompt_tokens[req_index]
                        + num_output_tokens
                    )
                    self.input_batch.num_tokens[req_index] = end_idx
                    self.input_batch.num_tokens_no_spec[req_index] = end_idx

            # Update the block IDs.
            if not resumed_from_preemption:
                if new_block_ids is not None:
                    # Append the new blocks to the existing block IDs.
                    for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
                        block_ids.extend(new_ids)
            else:
                assert req_index is None
                assert new_block_ids is not None
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.

                if self.use_async_scheduling and num_output_tokens > 0:
                    # We must recover the output token ids for resumed requests in the
                    # async scheduling case, so that correct input_ids are obtained.
                    resumed_token_ids = req_data.all_token_ids[req_id]
                    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]

                reqs_to_add.append(req_state)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens
            if new_block_ids is not None:
                self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_token_index:end_token_index
                ] = new_token_ids
                self.input_batch.num_tokens_no_spec[req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, []
            )
            num_spec_tokens = len(spec_token_ids)
            # For async scheduling, token_ids_cpu assigned from
            # spec_token_ids are placeholders and will be overwritten in
            # _prepare_input_ids.
            if num_spec_tokens:
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index
                ] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

            # When speculative decoding is used with structured output,
            # the scheduler can drop draft tokens that do not
            # conform to the schema. This can result in
            # scheduler_output.scheduled_spec_decode_tokens being empty,
            # even when speculative decoding is enabled.
            self.input_batch.spec_token_ids[req_index].clear()
            self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)

            # With async scheduling, when no draft token ids are available
            # (self._draft_token_ids is None), remove this request's spec
            # decode info from scheduler_output so that normal sampling is
            # used rather than rejection sampling.
            if self.use_async_scheduling:
                req_state.prev_num_draft_len = num_spec_tokens
                if num_spec_tokens and self._draft_token_ids is None:
                    scheduler_output.total_num_scheduled_tokens -= num_spec_tokens
                    scheduler_output.num_scheduled_tokens[req_id] -= num_spec_tokens
                    scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for request in reqs_to_add:
            self.input_batch.add_request(request)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch if it needs to.
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    def _update_states_after_model_execute(
925
926
        self, output_token_ids: torch.Tensor
    ) -> None:
927
928
929
930
931
932
933
934
935
936
937
938
        """Update the cached states after model execution.

        This is used for MTP/EAGLE for hybrid models, as in linear attention,
        only the last token's state is kept. In MTP/EAGLE, for draft tokens
        the state are kept util we decide how many tokens are accepted for
        each sequence, and a shifting is done during the next iteration
        based on the number of accepted tokens.
        """
        if not self.model_config.is_hybrid or not self.speculative_config:
            return

        # Find the number of accepted tokens for each sequence.
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
        num_accepted_tokens = (
            (
                torch.cat(
                    [
                        output_token_ids,
                        torch.full(
                            (output_token_ids.size(0), 1),
                            -1,
                            device=output_token_ids.device,
                        ),
                    ],
                    dim=1,
                )
                == -1
            )
            .int()
            .argmax(-1)
            .cpu()
            .numpy()
        )
959
960
961
        for i, num_tokens in enumerate(num_accepted_tokens):
            self.input_batch.num_accepted_tokens_cpu[i] = num_tokens

962
    def _init_mrope_positions(self, req_state: CachedRequestState):
        """Compute and store the M-RoPE positions for a request.

        Calls the model's ``get_mrope_input_positions`` with the request's
        prompt token ids and multimodal features. The result is stored on
        ``req_state.mrope_positions`` and ``req_state.mrope_position_delta``.

        Raises:
            AssertionError: if the model does not support M-RoPE or the
                request has no ``prompt_token_ids``.
        """
        model = self.get_model()
        assert supports_mrope(model), "M-RoPE support is not implemented."
        assert req_state.prompt_token_ids is not None, (
            "M-RoPE requires prompt_token_ids to be available."
        )
        mrope_model = cast(SupportsMRoPE, model)

        req_state.mrope_positions, req_state.mrope_position_delta = (
            mrope_model.get_mrope_input_positions(
                req_state.prompt_token_ids,
                req_state.mm_features,
            )
        )

    def _extract_mm_kwargs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> BatchedTensorInputs:
        """Build the raw multimodal kwargs for the newly scheduled requests.

        Only used by models flagged ``is_multimodal_raw_input_only_model``.
        Returns ``{}`` for any other model, or when ``scheduler_output`` is
        falsy.

        Every non-``None`` ``feature.data`` from the ``mm_features`` of
        ``scheduler_output.scheduled_new_reqs`` is passed to
        ``group_mm_kwargs_by_modality``. The yielded groups are merged into a
        single dict, with later groups overwriting duplicate keys.
        """
        if not scheduler_output or not self.is_multimodal_raw_input_only_model:
            return {}

        mm_kwargs: list[MultiModalKwargsItem] = [
            feature.data
            for req in scheduler_output.scheduled_new_reqs
            for feature in req.mm_features
            if feature.data is not None
        ]

        # Input all modalities at once
        model = cast(SupportsMultiModal, self.model)
        mm_kwargs_combined: BatchedTensorInputs = {}
        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
            multimodal_cpu_fields=model.multimodal_cpu_fields,
        ):
            mm_kwargs_combined.update(mm_kwargs_group)

        return mm_kwargs_combined

    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
        """Return dummy multimodal kwargs for ``num_seqs`` sequences.

        Returns ``{}`` unless the model is ``is_multimodal_raw_input_only_model``.
        Otherwise it delegates to ``_get_mm_dummy_batch``, using the modality
        reported by ``mm_budget.get_modality_with_max_tokens()``.

        Raises:
            AssertionError: if ``self.mm_budget`` is ``None`` for such a model.
        """
        if not self.is_multimodal_raw_input_only_model:
            return {}

        mm_budget = self.mm_budget
        assert mm_budget is not None

        dummy_modality = mm_budget.get_modality_with_max_tokens()
        return self._get_mm_dummy_batch(dummy_modality, num_seqs)

    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
1017
        cumsum_dtype: np.dtype | None = None,
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.
        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_tokens])
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        total_num_tokens = cu_num_tokens[-1]
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

1034
    def _prepare_input_ids(
1035
1036
1037
1038
        self,
        scheduler_output: "SchedulerOutput",
        total_num_scheduled_tokens: int,
        cu_num_tokens: np.ndarray,
1039
    ) -> None:
1040
        """Prepare the input IDs for the current batch.
1041

1042
1043
1044
1045
1046
1047
1048
        Carefully handles the `prev_sampled_token_ids` which can be cached
        from the previous engine iteration, in which case those tokens on the
        GPU need to be copied into the corresponding slots into input_ids."""

        if self.input_batch.prev_sampled_token_ids is None:
            # Normal scheduling case
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
1049
1050
1051
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
1052
1053
1054
1055
1056
1057
1058
            return

        # Async scheduling case, where some decode requests from the previous
        # iteration won't have entries in input_ids_cpu and need to be copied
        # on the GPU from prev_sampled_token_ids.
        prev_req_id_to_index = self.input_batch.prev_req_id_to_index
        assert prev_req_id_to_index is not None
1059
1060
1061
1062
        sample_flattened_indices: list[int] = []
        spec_flattened_indices: list[int] = []
        prev_common_req_indices: list[int] = []
        prev_draft_token_indices: list[int] = []
1063
1064
        indices_match = True
        max_flattened_index = -1
1065
1066
1067
        total_num_spec_tokens = 0
        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

1068
1069
1070
1071
1072
        for req_id, cur_index in self.input_batch.req_id_to_index.items():
            if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
                prev_common_req_indices.append(prev_index)
                # We need to compute the flattened input_ids index of the
                # last token in each common request.
1073
1074
                draft_len = len(scheduled_spec_tokens.get(req_id, ()))
                total_num_spec_tokens += draft_len
1075
                flattened_index = cu_num_tokens[cur_index].item() - 1
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
                # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
                # sample_flattened_indices = [0, 2, 5]
                # spec_flattened_indices = [1,   3, 4,    6, 7]
                sample_flattened_indices.append(flattened_index - draft_len)
                spec_flattened_indices.extend(
                    range(flattened_index - draft_len + 1, flattened_index + 1)
                )
                start = prev_index * self.num_spec_tokens
                # prev_draft_token_indices is used to find which draft_tokens_id
                # should be copied to input_ids
                # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
                # flatten draft_tokens_id [1,2,3,4,5,6]
                # draft_len of each request [1, 2, 1]
                # then prev_draft_token_indices is [0,   2, 3,   4]
                prev_draft_token_indices.extend(range(start, start + draft_len))
1091
                indices_match &= prev_index == flattened_index
1092
                max_flattened_index = max(max_flattened_index, flattened_index)
1093
1094
1095
        num_commmon_tokens = len(sample_flattened_indices)
        total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
        if num_commmon_tokens < total_without_spec:
1096
1097
1098
            # If not all requests are decodes from the last iteration,
            # We need to copy the input_ids_cpu to the GPU first.
            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
1099
1100
1101
            if self.enable_prompt_embeds:
                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
1102
1103
        if num_commmon_tokens == 0:
            # No requests in common with the previous iteration
1104
            # So input_ids.cpu will have all the input ids.
1105
1106
1107
1108
1109
1110
1111
            return
        if indices_match and max_flattened_index == (num_commmon_tokens - 1):
            # Common-case optimization: the batch is unchanged
            # and no reordering happened.
            # The indices are both the same permutation of 0..N-1 so
            # we can copy directly using a single slice.
            self.input_ids.gpu[:num_commmon_tokens].copy_(
1112
1113
1114
                self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0],
                non_blocking=True,
            )
1115
1116
            if self.enable_prompt_embeds:
                self.is_token_ids.gpu[:num_commmon_tokens] = True
1117
            return
1118
        # Upload the index tensors asynchronously so the scatter can be non-blocking.
1119
1120
        sampled_tokens_index_tensor = torch.tensor(
            sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
1121
        ).to(self.device, non_blocking=True)
1122
        prev_common_req_indices_tensor = torch.tensor(
1123
1124
            prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
1125
1126
        self.input_ids.gpu.scatter_(
            dim=0,
1127
            index=sampled_tokens_index_tensor,
1128
            src=self.input_batch.prev_sampled_token_ids[
1129
1130
1131
                prev_common_req_indices_tensor, 0
            ],
        )
1132

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
        # Scatter the draft tokens after the sampled tokens are scattered.
        if self._draft_token_ids is None or not spec_flattened_indices:
            return

        assert isinstance(self._draft_token_ids, torch.Tensor)
        draft_tokens_index_tensor = torch.tensor(
            spec_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)
        prev_draft_token_indices_tensor = torch.tensor(
            prev_draft_token_indices, dtype=torch.int64, pin_memory=self.pin_memory
        ).to(self.device, non_blocking=True)

        # because input_ids dtype is torch.int32,
        # so convert draft_token_ids to torch.int32 here.
        draft_token_ids = self._draft_token_ids.to(dtype=torch.int32)
        self._draft_token_ids = None

        self.input_ids.gpu.scatter_(
            dim=0,
            index=draft_tokens_index_tensor,
            src=draft_token_ids.flatten()[prev_draft_token_indices_tensor],
        )

1156
1157
    def _get_encoder_seq_lens(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        kv_cache_spec: KVCacheSpec,
        num_reqs: int,
    ) -> np.ndarray | None:
        """Build per-request encoder sequence lengths for cross-attention.

        Returns ``None`` unless ``kv_cache_spec`` is a ``CrossAttentionSpec``.
        Otherwise returns an int32 array of length ``num_reqs``. Each request
        whose id appears in ``scheduled_encoder_inputs`` gets
        ``self.max_encoder_len``; every other entry is 0.

        Raises:
            KeyError: if a scheduled request id is not in the persistent batch.
        """
        if not isinstance(kv_cache_spec, CrossAttentionSpec):
            return None

        # Build encoder_seq_lens array mapping request indices to
        # encoder lengths for inputs scheduled in this batch
        encoder_seq_lens = np.zeros(num_reqs, dtype=np.int32)
        for req_id in scheduled_encoder_inputs:
            req_index = self.input_batch.req_id_to_index[req_id]
            encoder_seq_lens[req_index] = self.max_encoder_len

        return encoder_seq_lens

    def _prepare_inputs(
1175
1176
1177
1178
        self,
        scheduler_output: "SchedulerOutput",
        num_scheduled_tokens: np.ndarray,
        max_num_scheduled_tokens: int,
1179
1180
    ) -> tuple[
        torch.Tensor,
1181
1182
1183
        SpecDecodeMetadata | None,
        UBatchSlices | None,
        torch.Tensor | None,
1184
    ]:
1185
1186
        """
        :return: tuple[
1187
            logits_indices, spec_decode_metadata,
1188
            ubatch_slices, num_tokens_across_dp,
1189
1190
        ]
        """
1191
1192
1193
1194
1195
1196
1197
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
1198
        self.input_batch.block_table.commit_block_table(num_reqs)
1199
1200
1201

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
1202
        req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
1203

1204
1205
        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
1206
        cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
1207
1208

        # Get positions.
1209
        positions_np = self.positions.np[:total_num_scheduled_tokens]
1210
1211
1212
1213
1214
        np.add(
            self.input_batch.num_computed_tokens_cpu[req_indices],
            arange,
            out=positions_np,
        )
1215

1216
1217
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
1218
        if self.uses_mrope:
1219
1220
            self._calc_mrope_positions(scheduler_output)

1221
1222
1223
1224
        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
1225
1226
1227
        token_indices = (
            positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1]
        )
1228
        token_indices_tensor = torch.from_numpy(token_indices)
1229

1230
1231
1232
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
1233
1234
1235
1236
1237
1238
        torch.index_select(
            self.input_batch.token_ids_cpu_tensor.flatten(),
            0,
            token_indices_tensor,
            out=self.input_ids.cpu[:total_num_scheduled_tokens],
        )
1239
        if self.enable_prompt_embeds:
1240
            is_token_ids = self.input_batch.is_token_ids_tensor.flatten()
1241
1242
1243
1244
            torch.index_select(
                is_token_ids,
                0,
                token_indices_tensor,
1245
1246
                out=self.is_token_ids.cpu[:total_num_scheduled_tokens],
            )
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279

        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
        # the InputBatch, we need to fill in the prompt embeds into the expected
        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
        if self.input_batch.req_prompt_embeds:
            output_idx = 0
            for req_idx in range(num_reqs):
                num_sched = num_scheduled_tokens[req_idx]

                # Skip if this request doesn't have embeddings
                if req_idx not in self.input_batch.req_prompt_embeds:
                    output_idx += num_sched
                    continue

                # Skip if no tokens scheduled
                if num_sched <= 0:
                    output_idx += num_sched
                    continue

                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]

                # Skip if trying to read beyond available embeddings
                if start_pos >= req_embeds.shape[0]:
                    output_idx += num_sched
                    continue

                # Copy available embeddings
                end_pos = start_pos + num_sched
                actual_end = min(end_pos, req_embeds.shape[0])
                actual_num_sched = actual_end - start_pos

                if actual_num_sched > 0:
1280
1281
1282
                    self.inputs_embeds.cpu[
                        output_idx : output_idx + actual_num_sched
                    ].copy_(req_embeds[start_pos:actual_end])
1283
1284

                output_idx += num_sched
1285

1286
1287
        self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
1288
1289

        # Prepare the attention metadata.
1290
        self.query_start_loc.np[0] = 0
1291
        self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
1292
1293
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
1294
        self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
1295
        self.query_start_loc.copy_to_gpu()
1296
        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
1297

1298
        num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
1299
        num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
1300
1301
1302
        uniform_decode = (
            max_num_scheduled_tokens == self.uniform_decode_query_len
        ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
1303
1304
1305
1306
1307
1308
1309

        # Disable DP padding when running eager to avoid excessive padding when
        # running prefills. This lets us set enforce_eager on the prefiller in
        # a P/D setup and still use CUDA graphs (enabled by this padding) on the
        # decoder.
        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE

1310
        ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
1311
1312
1313
1314
1315
1316
1317
            num_tokens_unpadded=num_tokens_unpadded,
            parallel_config=self.parallel_config,
            allow_microbatching=True,
            allow_dp_padding=allow_dp_padding,
            num_tokens_padded=num_tokens_padded,
            uniform_decode=uniform_decode,
            num_scheduled_tokens_per_request=num_scheduled_tokens,
1318
        )
1319

1320
        self.seq_lens.np[:num_reqs] = (
1321
1322
            self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
        )
1323
        # Fill unused with 0 for full cuda graph mode.
1324
1325
        self.seq_lens.np[num_reqs:].fill(0)
        self.seq_lens.copy_to_gpu()
1326

1327
        num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
1328
1329
1330
1331
1332
1333
1334
        num_tokens_np = np.array(num_tokens, dtype=np.int32)

        # Record the index of requests that should not be sampled,
        # so that we could clear the sampled tokens before returning
        discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
1335
1336
1337
        self.discard_request_indices.np[: self.num_discarded_requests] = (
            discard_request_indices
        )
1338
1339
1340

        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)

1341
        # Copy the tensors to the GPU.
1342
1343
1344
1345
1346
        self._prepare_input_ids(
            scheduler_output,
            total_num_scheduled_tokens,
            cu_num_tokens,
        )
1347

1348
        if self.uses_mrope:
1349
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
1350
1351
            self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
1352
1353
                non_blocking=True,
            )
1354
1355
        else:
            # Common case (1D positions)
1356
            self.positions.copy_to_gpu(total_num_scheduled_tokens)
1357

1358
        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
1359
1360
1361
1362
1363
1364
1365
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
1366
            num_draft_tokens = None
1367
            spec_decode_metadata = None
1368
            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
1369
1370
1371
1372
1373
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
1374
1375
1376
            # For chunked prefills, use -1 as mask rather than 0, as guided
            # decoding may rollback speculative tokens.
            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
1377
1378
1379
1380
            for (
                req_id,
                draft_token_ids,
            ) in scheduler_output.scheduled_spec_decode_tokens.items():
1381
1382
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)
1383
1384
1385
1386
1387
1388
1389
1390
                num_decode_draft_tokens[req_idx] = (
                    len(draft_token_ids)
                    if (
                        self.input_batch.num_computed_tokens_cpu[req_idx]
                        >= self.input_batch.num_prompt_tokens[req_idx]
                    )
                    else -1
                )
1391
            spec_decode_metadata = self._calc_spec_decode_metadata(
1392
1393
                num_draft_tokens, cu_num_tokens
            )
1394
            logits_indices = spec_decode_metadata.logits_indices
1395
            num_sampled_tokens = num_draft_tokens + 1
1396
            # For DECODE only cuda graph of some attention backends (e.g., GDN).
1397
            self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens
1398
1399
            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
            self.num_decode_draft_tokens.copy_to_gpu()
1400

1401
1402
1403
1404
1405
        # Hot-Swap lora model
        if self.lora_config:
            assert (
                np.sum(num_sampled_tokens)
                <= self.vllm_config.scheduler_config.max_num_batched_tokens
1406
            )
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
            self.set_active_loras(
                self.input_batch, num_scheduled_tokens, num_sampled_tokens
            )

        return (
            logits_indices,
            spec_decode_metadata,
            ubatch_slices,
            num_tokens_across_dp,
        )

    def _build_attention_metadata(
        self,
        total_num_scheduled_tokens: int,
        max_num_scheduled_tokens: int,
        num_reqs: int,
        ubatch_slices: UBatchSlices | None = None,
        logits_indices: torch.Tensor | None = None,
        use_spec_decode: bool = False,
        for_cudagraph_capture: bool = False,
        scheduled_encoder_inputs: dict[str, list[int]] | None = None,
        cascade_attn_prefix_lens: list[list[int]] | None = None,
    ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
        """Build attention metadata for every attention layer of the model.

        One ``CommonAttentionMetadata`` is assembled per KV cache group; each
        attention group inside that KV cache group then builds its own
        backend-specific metadata from it, shared by all layers of that
        attention group.

        Args:
            total_num_scheduled_tokens: Number of tokens scheduled this step.
            max_num_scheduled_tokens: Largest per-request query length.
            num_reqs: Number of requests in the batch.
            ubatch_slices: Micro-batch slices. When given, metadata is built
                per micro-batch and ``attn_metadata`` is a list holding one
                ``{layer_name: metadata}`` dict per micro-batch.
            logits_indices: Indices of the tokens whose logits are computed;
                staged for KV-sharing fast prefill when that is enabled.
            use_spec_decode: Whether speculative decoding is active this step.
            for_cudagraph_capture: Use ``build_for_cudagraph_capture`` instead
                of ``build`` and report ``max_model_len`` as ``max_seq_len``.
            scheduled_encoder_inputs: Scheduled encoder inputs per request id.
            cascade_attn_prefix_lens: Cascade prefix length indexed as
                ``[kv_cache_group_id][attn_group_idx]``; None disables cascade
                attention (prefix length 0 everywhere).

        :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
        """
        logits_indices_padded = None
        num_logits_indices = None
        if logits_indices is not None:
            num_logits_indices = logits_indices.size(0)
            if self.cache_config.kv_sharing_fast_prefill:
                logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                    logits_indices
                )

        # update seq_lens of decode reqs under DCP.
        if self.dcp_world_size > 1:
            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
                self.seq_lens.cpu[:num_reqs],
                self.dcp_world_size,
                self.dcp_rank,
                self.parallel_config.cp_kv_cache_interleave_size,
            )
            self.dcp_local_seq_lens.copy_to_gpu(num_reqs)

        attn_metadata: PerLayerAttnMetadata = {}
        if ubatch_slices is not None:
            # One {layer_name: metadata} dict per micro-batch.
            attn_metadata = [dict() for _ in range(len(ubatch_slices))]

        # Used in the below loop
        query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
        query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
        seq_lens = self.seq_lens.gpu[:num_reqs]
        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
        num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
            :num_reqs
        ]

        dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
        if self.dcp_world_size > 1:
            dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
            dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]

        spec_decode_common_attn_metadata = None

        if for_cudagraph_capture:
            # For some attention backends (e.g. FA) with sliding window models we need
            # to make sure the backend sees a max_seq_len that is larger than the
            # sliding window size when capturing, so the correct kernel is selected.
            max_seq_len = self.max_model_len
        else:
            max_seq_len = self.seq_lens.np[:num_reqs].max().item()

        if use_spec_decode:
            self.num_accepted_tokens.np[:num_reqs] = (
                self.input_batch.num_accepted_tokens_cpu[:num_reqs]
            )
            self.num_accepted_tokens.np[num_reqs:].fill(1)
            self.num_accepted_tokens.copy_to_gpu()

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_gid, kv_cache_group in enumerate(
            self.kv_cache_config.kv_cache_groups
        ):
            encoder_seq_lens = self._get_encoder_seq_lens(
                scheduled_encoder_inputs or {},
                kv_cache_group.kv_cache_spec,
                num_reqs,
            )

            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                # Encoder-only layers do not have KV cache, so we need to
                # create a dummy block table and slot mapping for them.
                blk_table_tensor = torch.zeros(
                    (num_reqs, 1),
                    dtype=torch.int32,
                    device=self.device,
                )
                slot_mapping = torch.zeros(
                    (total_num_scheduled_tokens,),
                    dtype=torch.int64,
                    device=self.device,
                )
            else:
                blk_table = self.input_batch.block_table[kv_cache_gid]
                blk_table_tensor = blk_table.get_device_tensor(num_reqs)
                slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]

                # Fill unused with -1. Needed for reshape_and_cache in full cuda
                # graph mode.
                blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc,
                query_start_loc_cpu=query_start_loc_cpu,
                seq_lens=seq_lens,
                seq_lens_cpu=seq_lens_cpu,
                num_computed_tokens_cpu=num_computed_tokens_cpu,
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                max_seq_len=max_seq_len,
                block_table_tensor=blk_table_tensor,
                slot_mapping=slot_mapping,
                logits_indices_padded=logits_indices_padded,
                num_logits_indices=num_logits_indices,
                causal=True,
                encoder_seq_lens=encoder_seq_lens,
                dcp_local_seq_lens=dcp_local_seq_lens,
                dcp_local_seq_lens_cpu=dcp_local_seq_lens_cpu,
            )

            # Hand the drafter the full-batch common metadata of one KV cache
            # group: for an EagleProposer, the first group containing its first
            # attention layer; otherwise simply the first group.
            if self.speculative_config and spec_decode_common_attn_metadata is None:
                if isinstance(self.drafter, EagleProposer):
                    if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                        spec_decode_common_attn_metadata = common_attn_metadata
                else:
                    spec_decode_common_attn_metadata = common_attn_metadata

            for attn_gid, attn_group in enumerate(self.attn_groups[kv_cache_gid]):
                cascade_attn_prefix_len = (
                    cascade_attn_prefix_lens[kv_cache_gid][attn_gid]
                    if cascade_attn_prefix_lens
                    else 0
                )
                builder = attn_group.get_metadata_builder()

                extra_attn_metadata_args = {}
                if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
                    extra_attn_metadata_args = dict(
                        num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
                        num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
                            :num_reqs
                        ],
                    )

                if ubatch_slices is not None:
                    assert type(attn_metadata) is list
                    # Keep `common_attn_metadata` bound to the full batch: a
                    # later attention group of this KV cache group splits it
                    # again, so the per-micro-batch pieces need their own name.
                    ubatch_common_attn_metadata_list = split_attn_metadata(
                        ubatch_slices, common_attn_metadata
                    )
                    for ubid, ubatch_common_attn_metadata in enumerate(
                        ubatch_common_attn_metadata_list
                    ):
                        builder = attn_group.get_metadata_builder(ubatch_id=ubid)
                        if for_cudagraph_capture:
                            attn_metadata_i = builder.build_for_cudagraph_capture(
                                ubatch_common_attn_metadata
                            )
                        else:
                            # NOTE(review): `extra_attn_metadata_args` is not
                            # forwarded here (unchanged behavior); its tensors
                            # cover the whole batch, not a single micro-batch.
                            attn_metadata_i = builder.build(
                                common_prefix_len=cascade_attn_prefix_len,
                                common_attn_metadata=ubatch_common_attn_metadata,
                            )
                        # Only this attention group's layers, matching the
                        # non-ubatch path; using the KV cache group's layer list
                        # would let the last attention group overwrite the
                        # metadata of its sibling groups.
                        for layer_name in attn_group.layer_names:
                            attn_metadata[ubid][layer_name] = attn_metadata_i
                else:
                    assert isinstance(attn_metadata, dict)
                    if for_cudagraph_capture:
                        attn_metadata_i = builder.build_for_cudagraph_capture(
                            common_attn_metadata
                        )
                    else:
                        attn_metadata_i = builder.build(
                            common_prefix_len=cascade_attn_prefix_len,
                            common_attn_metadata=common_attn_metadata,
                            **extra_attn_metadata_args,
                        )
                    for layer_name in attn_group.layer_names:
                        attn_metadata[layer_name] = attn_metadata_i

        return attn_metadata, spec_decode_common_attn_metadata

    def _compute_cascade_attn_prefix_lens(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: list[int],
    ) -> list[list[int]] | None:
        """
        :return: Optional[cascade_attn_prefix_lens]
            cascade_attn_prefix_lens is 2D: ``[kv_cache_group_id][attn_group_idx]``,
            None if we should not use cascade attention
        """
1611

1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
        use_cascade_attn = False
        num_kv_cache_groups = len(self.kv_cache_config.kv_cache_groups)
        cascade_attn_prefix_lens: list[list[int]] = [
            [] for _ in range(num_kv_cache_groups)
        ]

        for kv_cache_gid in range(num_kv_cache_groups):
            for attn_group in self.attn_groups[kv_cache_gid]:
                if isinstance(attn_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                    cascade_attn_prefix_len = 0
                else:
                    # 0 if cascade attention should not be used
                    cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
                        num_scheduled_tokens,
                        num_common_prefix_blocks[kv_cache_gid],
                        attn_group.kv_cache_spec,
                        attn_group.get_metadata_builder(),
                    )
                cascade_attn_prefix_lens[kv_cache_gid].append(cascade_attn_prefix_len)
                use_cascade_attn |= cascade_attn_prefix_len > 0

        return cascade_attn_prefix_lens if use_cascade_attn else None
1634

1635
1636
1637
1638
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.
            kv_cache_spec: KV cache spec of the attention group; provides the
                block size, KV head count and sliding-window / chunked-local
                attention settings. Must be an ``AttentionSpec``.
            attn_metadata_builder: Builder whose ``use_cascade_attention``
                makes the final decision.

        Returns:
            int: Length of common prefix in tokens (a plain Python ``int``),
            or 0 when cascade attention should not be used.
        """
        block_size = kv_cache_spec.block_size
        common_prefix_len = num_common_prefix_blocks * block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
        )
        # common_prefix_len should be a multiple of the block size.
        # The `.min()` reduction above yields a NumPy scalar; cast back to a
        # Python int so the value matches the annotated `-> int` contract.
        common_prefix_len = int(common_prefix_len // block_size * block_size)
        use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or (
            isinstance(kv_cache_spec, FullAttentionSpec)
            and kv_cache_spec.sliding_window is not None
        )
        use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or (
            isinstance(kv_cache_spec, FullAttentionSpec)
            and kv_cache_spec.attention_chunk_size is not None
        )
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=use_sliding_window,
            use_local_attention=use_local_attention,
            num_sms=self.num_sms,
            dcp_world_size=self.dcp_world_size,
        )
        return common_prefix_len if use_cascade else 0

    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
1736
        for index, req_id in enumerate(self.input_batch.req_ids):
1737
1738
1739
            req = self.requests[req_id]
            assert req.mrope_positions is not None

1740
1741
            num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
1742
            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
1743
1744
                req.prompt_token_ids, req.prompt_embeds
            )
1745
1746

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
1747
1748
                prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

1762
1763
1764
                self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[
                    :, src_start:src_end
                ]
1765
1766
1767
1768
1769
1770
1771
                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

1772
                assert req.mrope_position_delta is not None
1773
                MRotaryEmbedding.get_next_input_positions_tensor(
1774
                    out=self.mrope_positions.np,
1775
1776
1777
1778
1779
                    out_offset=dst_start,
                    mrope_position_delta=req.mrope_position_delta,
                    context_len=num_computed_tokens + prompt_part_len,
                    num_new_tokens=completion_part_len,
                )
1780
1781
1782

                mrope_pos_ptr += completion_part_len

1783
1784
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors used to verify draft tokens.

        Each request samples ``num_draft_tokens[i] + 1`` positions: one per
        draft token ("target" positions) plus a final "bonus" position.

        Example::

            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
            ->
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
                                       206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]

        ``target_logits_indices`` and ``bonus_logits_indices`` index into
        ``logits_indices`` (the gathered positions), not the flat token batch.
        """
        # Every request samples its draft tokens plus one bonus token.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1

        # cu_num_sampled_tokens: [4, 5, 8, 9, 11]
        # pos_in_window:         [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled_tokens, pos_in_window = self._get_cumsum_and_arange(
            num_sampled_tokens, cumsum_dtype=np.int32
        )
        # First sampled position of each request in the token batch, repeated
        # once per sampled token: [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        window_starts = cu_num_scheduled_tokens - num_sampled_tokens
        logits_indices = np.repeat(window_starts, num_sampled_tokens)
        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += pos_in_window

        # The bonus position is the last sampled position of each request.
        # [3, 4, 7, 8, 10]
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # cu_num_draft_tokens: [3, 3, 5, 5, 6]
        # pos_in_draft:        [0, 1, 2, 0, 1, 0]
        cu_num_draft_tokens, pos_in_draft = self._get_cumsum_and_arange(
            num_draft_tokens, cumsum_dtype=np.int32
        )
        # Offset of each request's window inside `logits_indices`, repeated
        # once per draft token: [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
        )
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += pos_in_draft

        # TODO: Optimize the CPU -> GPU copy.
        def to_device(array: np.ndarray) -> torch.Tensor:
            return torch.from_numpy(array).to(self.device, non_blocking=True)

        cu_num_draft_tokens_t = to_device(cu_num_draft_tokens)
        cu_num_sampled_tokens_t = to_device(cu_num_sampled_tokens)
        logits_indices_t = to_device(logits_indices)
        target_logits_indices_t = to_device(target_logits_indices)
        bonus_logits_indices_t = to_device(bonus_logits_indices)

        # Gather the input ids at the sampled positions, then take the entry
        # one slot after each target position.
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        draft_token_ids = self.input_ids.gpu[logits_indices_t][
            target_logits_indices_t + 1
        ]

        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens_t,
            cu_num_sampled_tokens=cu_num_sampled_tokens_t,
            target_logits_indices=target_logits_indices_t,
            bonus_logits_indices=bonus_logits_indices_t,
            logits_indices=logits_indices_t,
        )

    def _prepare_kv_sharing_fast_prefill(
        self,
        logits_indices: torch.Tensor,
    ) -> torch.Tensor:
        assert self.kv_sharing_fast_prefill_logits_indices is not None
        num_logits = logits_indices.shape[0]
        assert num_logits > 0
1869
        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices)
1870
1871
1872
1873
1874
        # There might have leftover indices in logits_indices[num_logits:]
        # from previous iterations, whose values may be greater than the
        # batch size in the current iteration. To ensure indices are always
        # valid, we fill the padded indices with the last index.
        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
1875
1876
1877
1878
1879
1880
            logits_indices[-1].item()
        )
        if (
            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and num_logits <= self.cudagraph_batch_sizes[-1]
        ):
1881
1882
1883
1884
1885
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
        else:
            num_logits_padded = num_logits
1886
1887
1888
        logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[
            :num_logits_padded
        ]
1889
1890
        return logits_indices_padded

1891
1892
1893
1894
1895
1896
1897
1898
    def _batch_mm_kwargs_from_scheduler(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
        """Batch multimodal kwargs from scheduled encoder inputs.

        Requests and their encoder inputs are visited in the order given by
        ``scheduler_output.scheduled_encoder_inputs``. Features whose ``data``
        is None are skipped.

        Args:
            scheduler_output: The scheduler output containing scheduled encoder
                inputs.

        Returns:
            A tuple of (mm_kwargs, mm_hashes_pos) where:
            - mm_kwargs: List of multimodal kwargs items to be batched
            - mm_hashes_pos: List of (mm_hash, position_info) tuples, aligned
              index-by-index with ``mm_kwargs``
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return [], []
        # Batch the multi-modal inputs.
        mm_kwargs = list[MultiModalKwargsItem]()
        # list of tuple (mm_hash, position_info)
        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]

            for mm_input_id in encoder_input_ids:
                mm_feature = req_state.mm_features[mm_input_id]
                if mm_feature.data is None:
                    continue

                mm_hash = mm_feature.identifier
                mm_kwargs.append(mm_feature.data)
                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))

        return mm_kwargs, mm_hashes_pos

    def _execute_mm_encoder(
        self, scheduler_output: "SchedulerOutput"
    ) -> list[torch.Tensor]:
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Each output is scattered into its placeholder layout, stored in
        ``self.encoder_cache`` under its mm_hash, and passed to
        ``maybe_save_ec_to_connector``.

        Args:
            scheduler_output: The scheduler output containing scheduled encoder
                inputs.

        Returns:
            The encoder outputs, one per batched multimodal item and in the
            order the items were batched; empty if nothing was scheduled.
        """
        # Batch the multi-modal inputs using the helper method.
        mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
            scheduler_output
        )

        if not mm_kwargs:
            return []

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        model = cast(SupportsMultiModal, self.model)
        encoder_outputs: list[torch.Tensor] = []
        # Index in `mm_kwargs` of the current group's first item. Groups are
        # consecutive runs of `mm_kwargs` (item order is preserved, which the
        # zip with `mm_hashes_pos` below also relies on).
        group_start = 0
        for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
            mm_kwargs,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
            multimodal_cpu_fields=model.multimodal_cpu_fields,
        ):
            group_end = group_start + num_items
            curr_group_outputs: list[torch.Tensor] = []

            # EVS-related change.
            # (ekhvedchenia): Temporary hack to limit peak memory usage when
            # processing multimodal data. This solves the issue with scheduler
            # putting too many video samples into a single batch. Scheduler
            # uses pruned vision tokens count to compare it versus compute
            # budget which is incorrect (Either input media size or non-pruned
            # output vision tokens count should be considered)
            # TODO(ywang96): Fix memory profiling to take EVS into account and
            # remove this hack.
            if (
                self.is_multimodal_pruning_enabled
                and modality == "video"
                and num_items > 1
            ):
                # Only this group's items: iterating every video item in
                # `mm_kwargs` would re-encode videos belonging to other groups
                # and produce more than `num_items` outputs for this group.
                for video_mm_kwargs_item in mm_kwargs[group_start:group_end]:
                    _, _, micro_batch_mm_inputs = next(
                        group_mm_kwargs_by_modality(
                            [video_mm_kwargs_item],
                            device=self.device,
                            pin_memory=self.pin_memory,
                            merge_by_field_config=model.merge_by_field_config,
                            multimodal_cpu_fields=model.multimodal_cpu_fields,
                        )
                    )

                    micro_batch_outputs = model.embed_multimodal(
                        **micro_batch_mm_inputs
                    )

                    curr_group_outputs.extend(micro_batch_outputs)
            else:
                # Run the encoder.
                # `curr_group_outputs` is either of the following:
                # 1. A tensor of shape (num_items, feature_size, hidden_size)
                # in case feature_size is fixed across all multimodal items.
                # 2. A list or tuple (length: num_items) of tensors,
                # each of shape (feature_size, hidden_size) in case the feature
                # size is dynamic depending on the input multimodal items.
                curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)  # type: ignore[assignment]

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=num_items,
            )
            encoder_outputs.extend(curr_group_outputs)
            group_start = group_end

        # Cache the encoder outputs by mm_hash
        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )
            logger.debug("Finish execute for mm hash %s", mm_hash)
            self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)

        return encoder_outputs

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
        shift_computed_tokens: int = 0,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Collect the cached encoder outputs needed by this step's tokens.

        For every request in the persistent batch, find the multimodal items
        whose placeholder range overlaps the tokens scheduled in this step,
        slice the overlapping part out of ``self.encoder_cache``, and mark the
        corresponding scheduled positions in ``self.is_mm_embed``.

        Args:
            scheduler_output: The scheduler output for the current step.
            shift_computed_tokens: Offset added to each request's
                ``num_computed_tokens`` when computing the overlap.

        Returns:
            A tuple ``(mm_embeds, is_mm_embed)``: the gathered embeddings in
            request/item order, and the mask over the first
            ``total_num_scheduled_tokens`` scheduled positions as returned by
            ``self.is_mm_embed.copy_to_gpu``.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

        mm_embeds = list[torch.Tensor]()
        is_mm_embed = self.is_mm_embed.cpu
        is_mm_embed[:total_num_scheduled_tokens] = False

        # Offset of the current request's first scheduled token within the
        # flattened batch of scheduled tokens.
        req_start_idx = 0
        should_sync_mrope_positions = False

        for req_id in self.input_batch.req_ids:
            mm_embeds_req: list[torch.Tensor] = []

            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens

            for mm_feature in req_state.mm_features:
                pos_info = mm_feature.mm_position
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    # NOTE(review): `break` (rather than `continue`) assumes
                    # mm_features are ordered by offset — confirm.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                # Part of this item's encoder output covered by the tokens
                # scheduled in this step.
                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens,
                )
                assert start_idx < end_idx

                mm_hash = mm_feature.identifier
                encoder_output = self.encoder_cache.get(mm_hash, None)
                assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                # Batch position corresponding to the item's first placeholder
                # token (may be negative if the item began in an earlier step).
                req_start_pos = req_start_idx + start_pos - num_computed_tokens
                is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                    True if is_embed is None else is_embed
                )

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds_req.append(mm_embeds_item)

            if self.is_multimodal_pruning_enabled and self.uses_mrope:
                # With multimodal pruning, the model recomputes this request's
                # M-RoPE positions (and may return updated embeddings); the
                # batch-wide positions are recomputed after the loop.
                assert req_state.mrope_positions is not None
                should_sync_mrope_positions = True
                mm_embeds_req, new_mrope_positions, new_delta = (
                    self.model.recompute_mrope_positions(
                        input_ids=req_state.prompt_token_ids,
                        multimodal_embeddings=mm_embeds_req,
                        mrope_positions=req_state.mrope_positions,
                        num_computed_tokens=req_state.num_computed_tokens,
                    )
                )
                req_state.mrope_positions.copy_(new_mrope_positions)
                req_state.mrope_position_delta = new_delta

            mm_embeds.extend(mm_embeds_req)
            req_start_idx += num_scheduled_tokens

        is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

        if should_sync_mrope_positions:
            self._calc_mrope_positions(scheduler_output)
            self.mrope_positions.copy_to_gpu(total_num_scheduled_tokens)

        return mm_embeds, is_mm_embed

2102
    def get_model(self) -> nn.Module:
        """Return the raw model, unwrapped from any graph/ubatch wrapper.

        ``self.model`` may be a ``CUDAGraphWrapper`` or ``UBatchWrapper``; in
        that case the wrapped module (via ``unwrap()``) is returned, otherwise
        ``self.model`` itself.
        """
        if isinstance(self.model, (CUDAGraphWrapper, UBatchWrapper)):
            return self.model.unwrap()
        return self.model

2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the loaded model can serve.

        ``"generate"`` is reported for text-generation models and
        ``"transcription"`` for models supporting transcription. A
        transcription-only model reports ``["transcription"]`` alone.
        """
        model = self.get_model()
        tasks = list[GenerationTask]()

        if is_text_generation_model(model):
            tasks.append("generate")

        if not supports_transcription(model):
            return tasks

        if model.supports_transcription_only:
            return ["transcription"]

        tasks.append("transcription")
        return tasks

2123
2124
2125
2126
2127
    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks the loaded model can serve.

        Returns an empty list for non-pooling models. Otherwise starts from
        the pooler's supported tasks and removes:

        * ``token_embed`` and ``token_classify`` when chunked prefill is
          enabled (these use ALL pooling, which chunked prefill does not
          support);
        * ``score`` unless the HF config's ``num_labels`` equals 1 (a missing
          ``num_labels`` is treated as 0).
        """
        model = self.get_model()
        if not is_pooling_model(model):
            return []

        supported_tasks = list(model.pooler.get_supported_tasks())

        if self.scheduler_config.enable_chunked_prefill:
            for task in ("token_embed", "token_classify"):
                if task in supported_tasks:
                    supported_tasks.remove(task)

            logger.debug_once(
                "Chunked prefill is not supported with "
                "token_embed and token_classify tasks "
                "which using ALL pooling. "
                "Please turn off chunked prefill by "
                "`--no-enable-chunked-prefill` before using it."
            )

        if "score" in supported_tasks:
            num_labels = getattr(self.model_config.hf_config, "num_labels", 0)
            if num_labels != 1:
                supported_tasks.remove("score")
                logger.debug_once("Score API is only enabled for num_labels == 1.")

        return supported_tasks

2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return every task this runner can serve for its runner type.

        Generation tasks are reported when ``runner_type == "generate"`` and
        pooling tasks when ``runner_type == "pooling"``; any other runner
        type yields an empty tuple.
        """
        tasks = list[SupportedTask]()
        runner_type = self.model_config.runner_type

        if runner_type == "generate":
            tasks += self.get_supported_generation_tasks()
        if runner_type == "pooling":
            tasks += self.get_supported_pooling_tasks()

        return tuple(tasks)

2162
    def sync_and_slice_intermediate_tensors(
        self,
        num_tokens: int,
        intermediate_tensors: IntermediateTensors | None,
        sync_self: bool,
    ) -> IntermediateTensors:
        """Return views into the persistent intermediate-tensor buffers.

        Args:
            num_tokens: Number of tokens to expose (and to copy, when syncing).
            intermediate_tensors: Tensors received from the previous pipeline
                stage; required when ``sync_self`` is True.
            sync_self: If True, first copy ``intermediate_tensors`` into
                ``self.intermediate_tensors`` with ``non_blocking=True``.

        Returns:
            An ``IntermediateTensors`` holding the leading ``num_tokens``
            entries of each persistent buffer, except that a ``"residual"``
            buffer scattered for sequence parallelism is sliced to
            ``num_tokens // tp`` entries.
        """
        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)

        def _slice_len(key: str) -> int:
            # When sequence parallelism is enabled, the "residual" tensor is
            # sharded across tensor parallel ranks, so each rank only needs
            # its own slice.
            return num_tokens // tp if key == "residual" and is_rs else num_tokens

        if sync_self:
            assert intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                copy_len = _slice_len(k)
                self.intermediate_tensors[k][:copy_len].copy_(
                    v[:copy_len], non_blocking=True
                )

        return IntermediateTensors(
            {k: v[: _slice_len(k)] for k, v in self.intermediate_tensors.items()}
        )

    def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
        """
        Step for the EPLB (Expert Parallelism Load Balancing) state.

        No-op unless ``parallel_config.enable_eplb`` is set. Otherwise asserts
        the model is a mixture-of-experts model and forwards ``is_dummy`` and
        ``is_profile`` to ``self.eplb_state.step``, logging balancedness
        statistics when ``eplb_config.log_balancedness`` is enabled.
        """
        if not self.parallel_config.enable_eplb:
            return

        assert self.eplb_state is not None
        model = self.get_model()
        assert is_mixture_of_experts(model)
        self.eplb_state.step(
            is_dummy,
            is_profile,
            log_stats=self.parallel_config.eplb_config.log_balancedness,
        )

2209
2210
2211
2212
    @staticmethod
    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
        """Extend the second micro-batch slice to account for padding.

        Should be called after attention metadata creation. Replaces
        ``ubatch_slices[1]`` in place with a ``UBatchSlice`` whose slice runs
        from the current second slice's ``token_slice.start`` up to
        ``num_total_tokens`` (num_tokens + padding). The same padded slice is
        passed for both ``UBatchSlice`` fields.
        """
        padded_second_ubatch_slice = slice(
            ubatch_slices[1].token_slice.start, num_total_tokens
        )
        ubatch_slices[1] = UBatchSlice(
            padded_second_ubatch_slice, padded_second_ubatch_slice
        )

2222
2223
2224
2225
2226
2227
    def _pool(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
    ) -> ModelRunnerOutput:
        """Run the model's pooler on this step's hidden states.

        Args:
            hidden_states: Model output; only the first
                ``num_scheduled_tokens`` entries are pooled.
            num_scheduled_tokens: Total number of scheduled tokens.
            num_scheduled_tokens_np: Per-request scheduled token counts, used
                to build the pooling cursor.

        Returns:
            A ``ModelRunnerOutput`` whose ``pooler_output`` holds, for each
            request, the pooled result when its sequence length equals its
            prompt length and ``None`` otherwise (presumably the prompt is
            still being processed across steps).
        """
        assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
            "Either all or none of the requests in a batch must be pooling request"
        )

        hidden_states = hidden_states[:num_scheduled_tokens]
        pooling_metadata = self.input_batch.get_pooling_metadata()
        pooling_metadata.build_pooling_cursor(
            num_scheduled_tokens_np.tolist(), device=hidden_states.device
        )
        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]

        model = cast(VllmModelForPooling, self.model)
        raw_pooler_output: PoolerOutput = model.pooler(
            hidden_states=hidden_states,
            pooling_metadata=pooling_metadata,
        )
        # Issue all copies to CPU as non-blocking, then call _sync_device()
        # before the host-side results are read below.
        raw_pooler_output = json_map_leaves(
            lambda x: x.to("cpu", non_blocking=True),
            raw_pooler_output,
        )
        self._sync_device()

        pooler_output: list[torch.Tensor | None] = [
            raw_output if seq_len == prompt_len else None
            for raw_output, seq_len, prompt_len in zip(
                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
            )
        ]

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
        )

2266
    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the (possibly padded) number of input tokens for this step.

        With CUDA graphs enabled and ``num_scheduled_tokens`` no larger than
        the last (presumably largest) entry of ``cudagraph_batch_sizes``, the
        count is padded via ``vllm_config.pad_for_cudagraph``. Otherwise
        (eager mode) it is rounded up to a multiple of the tensor-parallel
        size when sequence parallelism is enabled with TP > 1, and returned
        unchanged in all other cases.
        """
        if (
            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and hasattr(self, "cudagraph_batch_sizes")
            and self.cudagraph_batch_sizes
            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
        ):
            # Use CUDA graphs.
            # Add padding to the batch size.
            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)

        # Eager mode.
        # Pad tokens to multiple of tensor_parallel_size when
        # enabled collective fusion for SP
        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
        if (
            self.compilation_config.pass_config.enable_sequence_parallelism
            and tp_size > 1
        ):
            return round_up(num_scheduled_tokens, tp_size)
        return num_scheduled_tokens

2288
    def _preprocess(
        self,
        scheduler_output: "SchedulerOutput",
        num_input_tokens: int,  # Padded
        intermediate_tensors: IntermediateTensors | None = None,
    ) -> tuple[
        torch.Tensor | None,
        torch.Tensor | None,
        torch.Tensor,
        IntermediateTensors | None,
        dict[str, Any],
        ECConnectorOutput | None,
    ]:
        """Build the model-forward inputs for this step.

        One of three input modes is used:

        * multimodal (non encoder-decoder) models on the first PP rank run the
          multimodal encoder and always feed embeddings;
        * with prompt embeds enabled on the first PP rank, scheduled positions
          flagged in ``is_token_ids`` are embedded and written into
          ``self.inputs_embeds``, which is then fed to the model;
        * otherwise token ids are fed directly.

        Args:
            scheduler_output: The scheduler output for the current step.
            num_input_tokens: Padded number of input tokens.
            intermediate_tensors: Tensors from the previous pipeline stage;
                required on non-first PP ranks.

        Returns:
            ``(input_ids, inputs_embeds, positions, intermediate_tensors,
            model_kwargs, ec_connector_output)``. Exactly one of
            ``input_ids`` / ``inputs_embeds`` is not None.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        is_first_rank = get_pp_group().is_first_rank

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        ec_connector_output = None

        if (
            self.supports_mm_inputs
            and is_first_rank
            and not self.model_config.is_encoder_decoder
        ):
            # Run the multimodal encoder if any.
            with self.maybe_get_ec_connector_output(
                scheduler_output,
                encoder_cache=self.encoder_cache,
            ) as ec_connector_output:
                self._execute_mm_encoder(scheduler_output)
                mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            inputs_embeds_scheduled = self.model.embed_input_ids(
                self.input_ids.gpu[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds,
                is_multimodal=is_mm_embed,
            )

            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled)

            input_ids = None
            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
            model_kwargs = {
                **self._init_model_kwargs(num_scheduled_tokens),
                **self._extract_mm_kwargs(scheduler_output),
            }
        elif self.enable_prompt_embeds and is_first_rank:
            # Get the input embeddings for the tokens that are not input embeds,
            # then put them into the appropriate positions.
            # TODO(qthequartermasterman): Since even when prompt embeds are
            # enabled, (a) not all requests will use prompt embeds, and (b)
            # after the initial prompt is processed, the rest of the generated
            # tokens will be token ids, it is not desirable to have the
            # embedding layer outside of the CUDA graph all the time. The v0
            # engine avoids this by "double compiling" the CUDA graph, once
            # with input_ids and again with inputs_embeds, for all num_tokens.
            # If a batch only has token ids, then including the embedding layer
            # in the CUDA graph will be more performant (like in the else case
            # below).
            token_ids_idx = (
                self.is_token_ids.gpu[:num_scheduled_tokens]
                .nonzero(as_tuple=False)
                .squeeze(1)
            )
            # Some tokens ids may need to become embeds
            if token_ids_idx.numel() > 0:
                token_ids = self.input_ids.gpu[token_ids_idx]
                tokens_to_embeds = self.model.embed_input_ids(input_ids=token_ids)
                self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds

            inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
            model_kwargs = self._init_model_kwargs(num_input_tokens)
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids.gpu[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = self._init_model_kwargs(num_input_tokens)

        if self.uses_mrope:
            positions = self.mrope_positions.gpu[:, :num_input_tokens]
        else:
            positions = self.positions.gpu[:num_input_tokens]

        if is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True
            )

        if (
            self.model_config.is_encoder_decoder
            and scheduler_output.scheduled_encoder_inputs
        ):
            # Run the encoder, just like we do with other multimodal inputs.
            # For an encoder-decoder model, our processing here is a bit
            # simpler, because the outputs are just passed to the decoder.
            # We are not doing any prompt replacement. We also will only
            # ever have a single encoder input.
            encoder_outputs = self._execute_mm_encoder(scheduler_output)
            model_kwargs.update({"encoder_outputs": encoder_outputs})

        return (
            input_ids,
            inputs_embeds,
            positions,
            intermediate_tensors,
            model_kwargs,
            ec_connector_output,
        )

2408
    def _sample(
        self,
        logits: torch.Tensor | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> SamplerOutput:
        """Sample the next tokens (and logprobs if needed) from ``logits``.

        Without speculative decoding, async output token ids are refreshed and
        ``self.sampler`` is used. With ``spec_decode_metadata``,
        ``self.rejection_sampler`` is run (no draft probs are passed) and the
        runner state is updated from the resulting sampled token ids.
        """
        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            # Update output token ids with tokens sampled in last step
            # if async scheduling and required by current sampling params.
            self.input_batch.update_async_output_token_ids()
            return self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )

        sampler_output = self.rejection_sampler(
            spec_decode_metadata,
            None,  # draft_probs
            logits,
            sampling_metadata,
        )
        self._update_states_after_model_execute(sampler_output.sampled_token_ids)
        return sampler_output

    def _bookkeeping_sync(
        self,
        scheduler_output: "SchedulerOutput",
        sampler_output: SamplerOutput,
        logits: torch.Tensor | None,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> tuple[
        dict[str, int],
        LogprobsLists | None,
        list[list[int]],
        dict[str, LogprobsTensors | None],
        list[str],
        dict[str, int],
        list[int],
    ]:
        """Fold this step's sampled tokens back into the persistent batch.

        Samples of requests listed in ``self.discard_request_indices`` are
        discarded. The remaining sampled token ids are appended to each
        request's CPU token buffer and ``output_token_ids``. Logprob and
        prompt-logprob outputs are also computed here.

        With async scheduling the sampled tokens are cached on the GPU
        (``input_batch.prev_sampled_token_ids``) and ``-1`` placeholders are
        written into the CPU token buffer instead of the real ids.

        Returns:
            ``(num_nans_in_logits, logprobs_lists, valid_sampled_token_ids,
            prompt_logprobs_dict, req_ids_output_copy,
            req_id_to_index_output_copy, invalid_req_indices)``. Under async
            scheduling ``valid_sampled_token_ids`` is empty and
            ``logprobs_lists`` is ``None``.
        """
        num_nans_in_logits = {}
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        discard_sampled_tokens_req_indices = self.discard_request_indices.np[
            : self.num_discarded_requests
        ]
        for i in discard_sampled_tokens_req_indices:
            # Move each discarded request's generator offset back by 4,
            # presumably to undo the advance caused by the discarded sample.
            # NOTE(review): the step of 4 relies on torch-internal generator
            # details — confirm against the torch version in use.
            gen = self.input_batch.generators.get(int(i))
            if gen is not None:
                gen.set_offset(gen.get_offset() - 4)

        # Copy some objects so they don't get modified after returning.
        # This is important when using async scheduling.
        req_ids_output_copy = self.input_batch.req_ids.copy()
        req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()

        # One row of sampled_token_ids per request in the batch.
        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
        sampled_token_ids = sampler_output.sampled_token_ids
        logprobs_tensors = sampler_output.logprobs_tensors
        invalid_req_indices = []
        cu_num_new_tokens: list[int] | None = None
        if not self.use_async_scheduling:
            # Get the valid generated tokens.
            max_gen_len = sampled_token_ids.shape[-1]
            if max_gen_len == 1:
                # No spec decode tokens.
                valid_sampled_token_ids = self._to_list(sampled_token_ids)
            else:
                # Includes spec decode tokens.
                valid_sampled_token_ids = self.rejection_sampler.parse_output(
                    sampled_token_ids,
                    self.input_batch.vocab_size,
                )
                if logprobs_tensors:
                    # Needed for extracting logprobs when spec decoding.
                    # This must be done prior to discarding sampled tokens.
                    cu_num_new_tokens = [0]
                    for toks in valid_sampled_token_ids:
                        cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
            # Mask out the sampled tokens that should not be sampled.
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[int(i)].clear()
        else:
            valid_sampled_token_ids = []
            invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
            invalid_req_indices_set = set(invalid_req_indices)

            # Cache the sampled tokens on the GPU and avoid CPU sync.
            # These will be copied into input_ids in the next step
            # when preparing inputs.
            # With spec decoding, this is done in propose_draft_token_ids().
            if self.input_batch.prev_sampled_token_ids is None:
                assert sampled_token_ids.shape[-1] == 1
                self.input_batch.prev_sampled_token_ids = sampled_token_ids
            self.input_batch.prev_req_id_to_index = {
                req_id: i
                for i, req_id in enumerate(self.input_batch.req_ids)
                if i not in invalid_req_indices_set
            }

        # Cache the sampled tokens in the model runner, so that the scheduler
        # doesn't need to send them back.
        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        req_ids = self.input_batch.req_ids
        for req_idx in range(num_sampled_tokens):
            if self.use_async_scheduling:
                # Real ids are not on the CPU yet; write a -1 placeholder.
                sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
            else:
                sampled_ids = valid_sampled_token_ids[req_idx]

            num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0

            if not sampled_ids:
                continue

            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + num_sampled_ids
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}"
            )

            self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
            self.input_batch.num_tokens[req_idx] = end_idx

            req_id = req_ids[req_idx]
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)

        logprobs_lists = (
            logprobs_tensors.tolists(cu_num_new_tokens)
            if not self.use_async_scheduling and logprobs_tensors is not None
            else None
        )

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        return (
            num_nans_in_logits,
            logprobs_lists,
            valid_sampled_token_ids,
            prompt_logprobs_dict,
            req_ids_output_copy,
            req_id_to_index_output_copy,
            invalid_req_indices,
        )

2567
2568
2569
2570
2571
2572
2573
2574
2575
2576
2577
2578
2579
2580
2581
    @contextmanager
    def synchronize_input_prep(self):
        """Guard reuse of CPU-side input tensors across steps.

        When ``self.prepare_inputs_event`` is set, wait on it before the
        wrapped block runs. In the async scheduling case the CPU->GPU transfer
        happens asynchronously, so the prior step must be finished with the
        reused CPU tensors. The event is recorded again when the block exits,
        even if it raises. When the event is ``None`` this context does
        nothing.
        """
        if self.prepare_inputs_event is not None:
            self.prepare_inputs_event.synchronize()
            try:
                yield
            finally:
                self.prepare_inputs_event.record()
        else:
            yield

2582
2583
    def _model_forward(
        self,
        input_ids: torch.Tensor | None = None,
        positions: torch.Tensor | None = None,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **model_kwargs: Any,
    ) -> Any:
        """Helper method to call the model forward pass.

        This method can be overridden by subclasses for model execution.
        Motivation: We can inspect only this method versus
        the whole execute_model, which has additional logic.

        Args:
            input_ids: Input token IDs
            positions: Token positions
            intermediate_tensors: Tensors from previous pipeline stages
            inputs_embeds: Input embeddings (alternative to input_ids)
            **model_kwargs: Additional model arguments, forwarded unchanged

        Returns:
            Model output tensor
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

2614
2615
2616
2617
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
2618
        intermediate_tensors: IntermediateTensors | None = None,
2619
2620
2621
2622
2623
2624
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        if self.execute_model_state is not None:
            raise RuntimeError(
                "State error: sample_tokens() must be called "
                "after execute_model() returns None."
            )
2625
2626
2627
2628
2629
2630
2631
2632
2633
2634
2635
2636
2637
2638
2639

        # self._draft_token_ids is None when `input_fits_in_drafter=False`
        # and there is no draft tokens scheduled. so it need to update the
        # spec_decoding info in scheduler_output with async_scheduling.
        # use deepcopy to avoid the modification has influence on the
        # scheduler_output in engine core process.
        # TODO(Ronald1995): deepcopy is expensive when there is a large
        # number of requests, optimize it later.
        if (
            self.use_async_scheduling
            and self.num_spec_tokens
            and self._draft_token_ids is None
        ):
            scheduler_output = deepcopy(scheduler_output)

2640
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
2641
        with record_function_or_nullcontext("gpu_model_runner: preprocess"):
2642
2643
2644
2645
            with self.synchronize_input_prep():
                # Update persistent batch states.
                self._update_states(scheduler_output)

2646
2647
2648
2649
2650
2651
2652
2653
                if has_ec_transfer() and get_ec_transfer().is_producer:
                    with self.maybe_get_ec_connector_output(
                        scheduler_output,
                        encoder_cache=self.encoder_cache,
                    ) as ec_connector_output:
                        self._execute_mm_encoder(scheduler_output)
                        return make_empty_encoder_model_runner_output(scheduler_output)

2654
                if not num_scheduled_tokens:
2655
2656
2657
2658
2659
2660
2661
2662
2663
2664
2665
2666
                    if (
                        self.parallel_config.distributed_executor_backend
                        == "external_launcher"
                        and self.parallel_config.data_parallel_size > 1
                    ):
                        # this is a corner case when both external launcher
                        # and DP are enabled, num_scheduled_tokens could be
                        # 0, and has_unfinished_requests in the outer loop
                        # returns True. before returning early here we call
                        # dummy run to ensure coordinate_batch_across_dp
                        # is called into to avoid out of sync issues.
                        self._dummy_run(1)
2667
2668
2669
2670
                    if not has_kv_transfer_group():
                        # Return empty ModelRunnerOutput if no work to do.
                        return EMPTY_MODEL_RUNNER_OUTPUT
                    return self.kv_connector_no_forward(
2671
2672
                        scheduler_output, self.vllm_config
                    )
2673
2674
2675
2676
                if self.cache_config.kv_sharing_fast_prefill:
                    assert not self.input_batch.num_prompt_logprobs, (
                        "--kv-sharing-fast-prefill produces incorrect "
                        "logprobs for prompt tokens, tokens, please disable "
2677
2678
                        "it when the requests need prompt logprobs"
                    )
2679

2680
2681
2682
2683
2684
2685
                num_reqs = self.input_batch.num_reqs
                req_ids = self.input_batch.req_ids
                tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
                num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
                max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())

2686
2687
2688
2689
                (
                    logits_indices,
                    spec_decode_metadata,
                    ubatch_slices,
2690
                    num_tokens_across_dp,
2691
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
2707
2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
                ) = self._prepare_inputs(
                    scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
                )

                cascade_attn_prefix_lens = None
                # Disable cascade attention when using microbatching (DBO)
                if self.cascade_attn_enabled and ubatch_slices is None:
                    # Pre-compute cascade attention prefix lengths
                    # NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
                    cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
                        num_scheduled_tokens_np,
                        scheduler_output.num_common_prefix_blocks,
                    )

                # TODO(lucas): move cudagraph dispatching here:
                #   https://github.com/vllm-project/vllm/issues/23789

                total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
                use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
                attn_metadata, spec_decode_common_attn_metadata = (
                    self._build_attention_metadata(
                        total_num_scheduled_tokens=total_num_scheduled_tokens,
                        max_num_scheduled_tokens=max_num_scheduled_tokens,
                        num_reqs=num_reqs,
                        ubatch_slices=ubatch_slices,
                        logits_indices=logits_indices,
                        use_spec_decode=use_spec_decode,
                        scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs,
                        cascade_attn_prefix_lens=cascade_attn_prefix_lens,
                    )
                )
2722

2723
2724
2725
2726
2727
2728
2729
2730
2731
2732
2733
                dp_rank = self.parallel_config.data_parallel_rank
                if ubatch_slices:
                    assert num_tokens_across_dp is not None
                    num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
                    self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
                elif num_tokens_across_dp is not None:
                    num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
                else:
                    num_input_tokens = self._get_num_input_tokens(
                        scheduler_output.total_num_scheduled_tokens
                    )
2734

2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
                (
                    input_ids,
                    inputs_embeds,
                    positions,
                    intermediate_tensors,
                    model_kwargs,
                    ec_connector_output,
                ) = self._preprocess(
                    scheduler_output, num_input_tokens, intermediate_tensors
                )
2745

2746
2747
2748
            uniform_decode = (
                max_num_scheduled_tokens == self.uniform_decode_query_len
            ) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
2749
            batch_desc = BatchDescriptor(
2750
2751
2752
                num_tokens=num_input_tokens,
                uniform_decode=uniform_decode,
                has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
2753
2754
            )
            cudagraph_runtime_mode, batch_descriptor = (
2755
                self.cudagraph_dispatcher.dispatch(
2756
                    batch_desc,
2757
2758
                    use_cascade_attn=cascade_attn_prefix_lens is not None,
                )
2759
            )
2760

2761
        # Set cudagraph mode to none if calc_kv_scales is true.
2762
2763
2764
2765
2766
2767
        # KV scales calculation involves dynamic operations that are incompatible
        # with CUDA graph capture.
        if self.calculate_kv_scales:
            cudagraph_runtime_mode = CUDAGraphMode.NONE
            # Mark KV scales as calculated after the first forward pass
            self.calculate_kv_scales = False
2768

2769
2770
        # Run the model.
        # Use persistent buffers for CUDA graphs.
2771
2772
        with (
            set_forward_context(
2773
2774
2775
2776
2777
2778
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                batch_descriptor=batch_descriptor,
2779
                ubatch_slices=ubatch_slices,
2780
            ),
2781
            record_function_or_nullcontext("gpu_model_runner: forward"),
2782
2783
            self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
        ):
2784
            model_output = self._model_forward(
2785
2786
2787
2788
2789
2790
2791
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            )

2792
        with record_function_or_nullcontext("gpu_model_runner: postprocess"):
2793
            if self.use_aux_hidden_state_outputs:
2794
                # True when EAGLE 3 is used.
2795
2796
                hidden_states, aux_hidden_states = model_output
            else:
2797
                # Common case.
2798
2799
2800
                hidden_states = model_output
                aux_hidden_states = None

2801
2802
2803
2804
2805
            if not self.broadcast_pp_output:
                # Common case.
                if not get_pp_group().is_last_rank:
                    # Return the intermediate tensors.
                    assert isinstance(hidden_states, IntermediateTensors)
2806
                    hidden_states.kv_connector_output = kv_connector_output
2807
                    self.kv_connector_output = kv_connector_output
2808
                    return hidden_states
2809

2810
                if self.is_pooling_model:
2811
                    # Return the pooling output.
2812
2813
2814
                    output = self._pool(
                        hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
                    )
2815
2816
                    output.kv_connector_output = kv_connector_output
                    return output
2817
2818

                sample_hidden_states = hidden_states[logits_indices]
2819
                logits = self.model.compute_logits(sample_hidden_states)
2820
2821
2822
2823
            else:
                # Rare case.
                assert not self.is_pooling_model

2824
                sample_hidden_states = hidden_states[logits_indices]
2825
                if not get_pp_group().is_last_rank:
2826
                    all_gather_tensors = {
2827
2828
2829
                        "residual": not is_residual_scattered_for_sp(
                            self.vllm_config, num_input_tokens
                        )
2830
                    }
2831
                    get_pp_group().send_tensor_dict(
2832
2833
                        hidden_states.tensors,
                        all_gather_group=get_tp_group(),
2834
2835
                        all_gather_tensors=all_gather_tensors,
                    )
2836
2837
                    logits = None
                else:
2838
                    logits = self.model.compute_logits(sample_hidden_states)
2839

2840
                model_output_broadcast_data: dict[str, Any] = {}
2841
2842
2843
                if logits is not None:
                    model_output_broadcast_data["logits"] = logits.contiguous()

2844
                broadcasted = get_pp_group().broadcast_tensor_dict(
2845
2846
                    model_output_broadcast_data, src=len(get_pp_group().ranks) - 1
                )
2847
2848
                assert broadcasted is not None
                logits = broadcasted["logits"]
2849

2850
2851
2852
2853
2854
2855
2856
2857
        self.execute_model_state = ExecuteModelState(
            scheduler_output,
            logits,
            spec_decode_metadata,
            spec_decode_common_attn_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
2858
            ec_connector_output,
2859
        )
2860
        self.kv_connector_output = kv_connector_output
2861
2862
2863
2864
2865
2866
        return None

    @torch.inference_mode
    def sample_tokens(
        self, grammar_output: "GrammarOutput | None"
    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
        """Sample tokens for the batch prepared by the preceding model forward.

        Consumes (and clears) the ephemeral ``self.execute_model_state`` and
        ``self.kv_connector_output`` stashed by ``execute_model``. It then does
        the following:

        - applies the structured-output bitmask (if any) and samples;
        - runs the speculative drafter when configured;
        - performs bookkeeping and the EPLB step;
        - builds the runner output.

        Args:
            grammar_output: Structured-output bitmask info to apply to the
                logits before sampling, or ``None``.

        Returns:
            - ``None`` when there is no pending model state and no KV-connector
              output (e.g. a non-final pipeline-parallel rank).
            - An empty ``ModelRunnerOutput`` that carries only the KV-connector
              output, when there is no pending state but KV-connector output
              exists.
            - Otherwise a ``ModelRunnerOutput``, or an
              ``AsyncGPUModelRunnerOutput`` when async scheduling is enabled.
        """
        kv_connector_output = self.kv_connector_output
        self.kv_connector_output = None

        if self.execute_model_state is None:
            # Nothing to do (PP non-final rank case), output isn't used.
            if not kv_connector_output:
                return None  # type: ignore[return-value]

            # In case of PP with kv transfer, we need to pass through the
            # kv_connector_output
            if kv_connector_output.is_empty():
                return EMPTY_MODEL_RUNNER_OUTPUT

            output = copy(EMPTY_MODEL_RUNNER_OUTPUT)
            output.kv_connector_output = kv_connector_output
            return output

        # Unpack ephemeral state.
        (
            scheduler_output,
            logits,
            spec_decode_metadata,
            spec_decode_common_attn_metadata,
            hidden_states,
            sample_hidden_states,
            aux_hidden_states,
            ec_connector_output,
        ) = self.execute_model_state
        # Clear ephemeral state.
        self.execute_model_state = None

        # Apply structured output bitmasks if present.
        if grammar_output is not None:
            apply_grammar_bitmask(
                scheduler_output, grammar_output, self.input_batch, logits
            )

        with record_function_or_nullcontext("gpu_model_runner: sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)

        self.input_batch.prev_sampled_token_ids = None

        def _run_drafter(sampled_token_ids):
            # Runs the drafter and stores its proposals in
            # self._draft_token_ids (later handed out by take_draft_token_ids).
            assert spec_decode_common_attn_metadata is not None
            with record_function_or_nullcontext("gpu_model_runner: draft"):
                self._draft_token_ids = self.propose_draft_token_ids(
                    scheduler_output,
                    sampled_token_ids,
                    self.input_batch.sampling_metadata,
                    hidden_states,
                    sample_hidden_states,
                    aux_hidden_states,
                    spec_decode_metadata,
                    spec_decode_common_attn_metadata,
                )

        spec_config = self.speculative_config
        use_padded_batch_for_eagle = (
            spec_config is not None
            and spec_config.use_eagle()
            and not spec_config.disable_padded_drafter_batch
        )

        # The drafter may have a shorter context window than the target model;
        # prefer the draft model's own max_model_len when it is configured.
        effective_drafter_max_model_len = self.max_model_len
        if effective_drafter_max_model_len is None:
            effective_drafter_max_model_len = self.model_config.max_model_len
        if (
            spec_config is not None
            and spec_config.draft_model_config is not None
            and spec_config.draft_model_config.max_model_len is not None
        ):
            effective_drafter_max_model_len = (
                spec_config.draft_model_config.max_model_len
            )

        # Drafting is only possible if the longest sequence plus the tokens we
        # would propose still fits in the drafter's context window.
        input_fits_in_drafter = (
            spec_decode_common_attn_metadata is not None
            and spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens
            <= effective_drafter_max_model_len
        )

        if use_padded_batch_for_eagle:
            assert spec_config is not None
            assert isinstance(self.drafter, EagleProposer)
            sampled_token_ids = sampler_output.sampled_token_ids
            if input_fits_in_drafter:
                # EAGLE speculative decoding can use the GPU sampled tokens
                # as inputs, and does not need to wait for bookkeeping to finish.
                _run_drafter(sampled_token_ids)
            elif self.valid_sampled_token_count_event is not None:
                # Drafting is skipped, but the padded next-token ids and the
                # valid-count copy are still produced (presumably so later
                # bookkeeping that reads them stays consistent - TODO confirm).
                assert spec_decode_common_attn_metadata is not None
                next_token_ids, valid_sampled_tokens_count = (
                    self.drafter.prepare_next_token_ids_padded(
                        spec_decode_common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests,
                    )
                )
                self._copy_valid_sampled_token_count(
                    next_token_ids, valid_sampled_tokens_count
                )

        with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
            (
                num_nans_in_logits,
                logprobs_lists,
                valid_sampled_token_ids,
                prompt_logprobs_dict,
                req_ids_output_copy,
                req_id_to_index_output_copy,
                invalid_req_indices,
            ) = self._bookkeeping_sync(
                scheduler_output,
                sampler_output,
                logits,
                hidden_states,
                scheduler_output.total_num_scheduled_tokens,
                spec_decode_metadata,
            )

        if (
            self.speculative_config
            and not use_padded_batch_for_eagle
            and input_fits_in_drafter
        ):
            # ngram and other speculative decoding methods use the sampled
            # tokens on the CPU, so they are run after bookkeeping.
            _run_drafter(valid_sampled_token_ids)

        with record_function_or_nullcontext("gpu_model_runner: eplb"):
            self.eplb_step()

        with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
            output = ModelRunnerOutput(
                req_ids=req_ids_output_copy,
                req_id_to_index=req_id_to_index_output_copy,
                sampled_token_ids=valid_sampled_token_ids,
                logprobs=logprobs_lists,
                prompt_logprobs_dict=prompt_logprobs_dict,
                pooler_output=[],
                kv_connector_output=kv_connector_output,
                ec_connector_output=ec_connector_output
                if self.supports_mm_inputs
                else None,
                num_nans_in_logits=num_nans_in_logits,
            )

        if not self.use_async_scheduling:
            return output

        with record_function_or_nullcontext(
            "gpu_model_runner: AsyncGPUModelRunnerOutput"
        ):
            async_output = AsyncGPUModelRunnerOutput(
                model_runner_output=output,
                sampled_token_ids=sampler_output.sampled_token_ids,
                logprobs_tensors=sampler_output.logprobs_tensors,
                invalid_req_indices=invalid_req_indices,
                async_output_copy_stream=self.async_output_copy_stream,
                vocab_size=self.input_batch.vocab_size,
            )
        with record_function_or_nullcontext(
            "gpu_model_runner: set_async_sampled_token_ids"
        ):
            # Save ref of sampled_token_ids CPU tensor if the batch contains
            # any requests with sampling params that require output ids.
            self.input_batch.set_async_sampled_token_ids(
                async_output.sampled_token_ids_cpu,
                async_output.async_copy_ready_event,
            )

        return async_output

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Hand off, and clear, the draft tokens proposed during the last step.

        Returns ``None`` when nothing is pending. Tensor proposals are turned
        into nested Python lists. The result is paired with the current
        ``input_batch.req_ids``.
        """
        pending = self._draft_token_ids
        if pending is None:
            return None

        req_ids = self.input_batch.req_ids
        draft_token_ids = (
            pending.tolist() if isinstance(pending, torch.Tensor) else pending
        )
        # One-shot hand-off: the runner forgets the drafts once taken.
        self._draft_token_ids = None
        return DraftTokenIds(req_ids, draft_token_ids)

    def _copy_valid_sampled_token_count(
        self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
    ) -> None:
        if self.valid_sampled_token_count_event is None:
            return

        default_stream = torch.cuda.current_stream()
        # Initialize a new stream to overlap the copy operation with
        # prepare_input of draft model.
        with torch.cuda.stream(self.valid_sampled_token_count_copy_stream):
            self.valid_sampled_token_count_copy_stream.wait_stream(default_stream)  # type: ignore
            counts = valid_sampled_tokens_count
            counts_cpu = self.valid_sampled_token_count_cpu
            counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
            self.valid_sampled_token_count_event.record()

        self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

    def _get_valid_sampled_token_count(self) -> list[int]:
        # Wait until valid_sampled_tokens_count is copied to cpu,
        prev_sampled_token_ids = self.input_batch.prev_sampled_token_ids
        if (
            self.valid_sampled_token_count_event is None
            or prev_sampled_token_ids is None
        ):
            return []

        counts_cpu = self.valid_sampled_token_count_cpu
        self.valid_sampled_token_count_event.synchronize()
        return counts_cpu[: prev_sampled_token_ids.shape[0]].tolist()

3079
3080
3081
    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[list[int]] | torch.Tensor:
        """Propose draft tokens with the configured speculative-decoding method.

        Dispatches on ``self.speculative_config`` to one of:

        - ``"ngram"``, ``"suffix"`` or ``"medusa"``: these expect
          ``sampled_token_ids`` as a CPU ``list[list[int]]``.
        - EAGLE-family methods (``spec_config.use_eagle()``):
          ``sampled_token_ids`` is a CPU list when the padded drafter batch is
          disabled, and a GPU tensor otherwise.

        Args:
            scheduler_output: Scheduler output for the current step.
            sampled_token_ids: Tokens sampled by the target model. The
                expected type depends on the method (see above).
            sampling_metadata: Sampling metadata forwarded to the drafter.
            hidden_states: Target-model hidden states for all scheduled
                tokens.
            sample_hidden_states: Target-model hidden states at the logits
                indices.
            aux_hidden_states: Auxiliary hidden states (EAGLE3); required when
                ``self.use_aux_hidden_state_outputs`` is set.
            spec_decode_metadata: Metadata for draft tokens verified in this
                step, or ``None`` if no draft tokens were scheduled.
            common_attn_metadata: Attention metadata. For EAGLE it may be
                rebuilt via ``prepare_inputs`` / ``prepare_inputs_padded``
                before being passed to the drafter.

        Returns:
            The draft token ids produced by the drafter's ``propose`` method.

        Raises:
            ValueError: If the speculative method is not handled here.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        spec_config = self.speculative_config
        assert spec_config is not None
        if spec_config.method == "ngram":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, NgramProposer)
            draft_token_ids = self.drafter.propose(
                sampled_token_ids,
                self.input_batch.req_ids,
                self.input_batch.num_tokens_no_spec,
                self.input_batch.token_ids_cpu,
                self.input_batch.spec_decode_unsupported_reqs,
            )
        elif spec_config.method == "suffix":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, SuffixDecodingProposer)
            draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids)
        elif spec_config.method == "medusa":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, MedusaProposer)

            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                medusa_hidden_states = sample_hidden_states
            else:
                assert spec_decode_metadata is not None, (
                    "No spec decode metadata for medusa"
                )
                # For each request, pick the hidden state of its last accepted
                # token. Each request occupies (num_draft + 1) rows in
                # sample_hidden_states.
                last_accepted_rows: list[int] = []
                offset = 0
                for num_draft, tokens in zip(
                    spec_decode_metadata.num_draft_tokens, sampled_token_ids
                ):
                    last_accepted_rows.append(offset + len(tokens) - 1)
                    offset += num_draft + 1
                row_indices = torch.tensor(last_accepted_rows, device=self.device)
                medusa_hidden_states = sample_hidden_states[row_indices]

            draft_token_ids = self.drafter.propose(
                target_hidden_states=medusa_hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif spec_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)

            if spec_config.disable_padded_drafter_batch:
                # When padded-batch is disabled, the sampled_token_ids should be
                # the cpu-side list[list[int]] of valid sampled tokens for each
                # request, with invalid requests having empty lists.
                assert isinstance(sampled_token_ids, list), (
                    "sampled_token_ids should be a python list when "
                    "padded-batch is disabled."
                )
                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                    sampled_token_ids,
                    self.requests,
                    self.input_batch,
                    scheduler_output.num_scheduled_tokens,
                )
            else:
                # When using padded-batch, the sampled_token_ids should be
                # the gpu tensor of sampled tokens for each request, of shape
                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                # value -1.
                assert isinstance(sampled_token_ids, torch.Tensor), (
                    "sampled_token_ids should be a torch.Tensor when "
                    "padded-batch is enabled."
                )
                next_token_ids, valid_sampled_tokens_count = (
                    self.drafter.prepare_next_token_ids_padded(
                        common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests,
                    )
                )
                self._copy_valid_sampled_token_count(
                    next_token_ids, valid_sampled_tokens_count
                )

            if spec_decode_metadata is None:
                # No draft tokens were verified this step: every scheduled
                # token feeds the drafter.
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                target_positions = self._get_positions(num_scheduled_tokens)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
            else:
                if spec_config.disable_padded_drafter_batch:
                    token_indices_to_sample = None
                    common_attn_metadata, token_indices = self.drafter.prepare_inputs(
                        common_attn_metadata,
                        sampled_token_ids,
                        spec_decode_metadata.num_draft_tokens,
                    )
                else:
                    # valid_sampled_tokens_count was produced above in the
                    # padded-batch branch (same disable_padded_drafter_batch flag).
                    common_attn_metadata, token_indices, token_indices_to_sample = (
                        self.drafter.prepare_inputs_padded(
                            common_attn_metadata,
                            spec_decode_metadata,
                            valid_sampled_tokens_count,
                        )
                    )

                target_token_ids = self.input_ids.gpu[token_indices]
                target_positions = self._get_positions(token_indices)
                if self.use_aux_hidden_state_outputs:
                    assert aux_hidden_states is not None
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1
                    )
                else:
                    target_hidden_states = hidden_states[token_indices]

            if self.supports_mm_inputs:
                mm_embed_inputs = self._gather_mm_embeddings(
                    scheduler_output,
                    shift_computed_tokens=1,
                )
            else:
                mm_embed_inputs = None

            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                last_token_indices=token_indices_to_sample,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                mm_embed_inputs=mm_embed_inputs,
            )
        else:
            # Previously this fell through to an UnboundLocalError on
            # draft_token_ids; fail with an explicit message instead.
            raise ValueError(
                f"Unsupported speculative decoding method: {spec_config.method!r}"
            )

        return draft_token_ids

    def update_config(self, overrides: dict[str, Any]) -> None:
        """Apply per-config overrides to this runner in place.

        Only ``load_config`` and ``model_config`` may be overridden; any other
        key fails an assertion. Each named attribute is replaced with the
        result of the module-level ``update_config`` helper.
        """
        allowed_config_names = {"load_config", "model_config"}
        for name, per_config_overrides in overrides.items():
            assert name in allowed_config_names, (
                f"Config `{name}` not supported. "
                f"Allowed configs: {allowed_config_names}"
            )
            current = getattr(self, name)
            setattr(self, name, update_config(current, per_config_overrides))

    def load_model(self, eep_scale_up: bool = False) -> None:
        """
        Args:
            eep_scale_up: the model loading is for elastic EP scale up.
        """
3248
3249
3250
3251
3252
        logger.info_once(
            "Starting to load model %s...",
            self.model_config.model,
            scope="global",
        )
3253
3254
3255
3256
3257
        global_expert_loads, old_global_expert_indices_per_model, rank_mapping = (
            EplbState.get_eep_state(self.parallel_config)
            if eep_scale_up
            else (None, None, None)
        )
3258

3259
3260
3261
        if self.parallel_config.enable_eplb:
            self.eplb_state = EplbState(self.parallel_config, self.device)
            eplb_models = 0
3262
        with DeviceMemoryProfiler() as m:
3263
            time_before_load = time.perf_counter()
3264
            model_loader = get_model_loader(self.load_config)
3265
            self.model = model_loader.load_model(
3266
3267
                vllm_config=self.vllm_config, model_config=self.model_config
            )
3268
            if self.lora_config:
3269
3270
3271
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )
3272
            if hasattr(self, "drafter"):
3273
                logger.info_once("Loading drafter model...")
3274
                self.drafter.load_model(self.model)
3275
3276
3277
3278
3279
                if (
                    hasattr(self.drafter, "model")
                    and is_mixture_of_experts(self.drafter.model)
                    and self.parallel_config.enable_eplb
                ):
3280
3281
3282
                    spec_config = self.vllm_config.speculative_config
                    assert spec_config is not None
                    assert spec_config.draft_model_config is not None
3283
3284
                    logger.info_once(
                        "EPLB is enabled for drafter model %s.",
3285
                        spec_config.draft_model_config.model,
3286
3287
3288
3289
3290
3291
3292
3293
3294
3295
3296
3297
3298
3299
3300
3301
                    )

                    global_expert_load = (
                        global_expert_loads[eplb_models]
                        if global_expert_loads
                        else None
                    )
                    old_global_expert_indices = (
                        old_global_expert_indices_per_model[eplb_models]
                        if old_global_expert_indices_per_model
                        else None
                    )
                    if self.eplb_state is None:
                        self.eplb_state = EplbState(self.parallel_config, self.device)
                    self.eplb_state.add_model(
                        self.drafter.model,
3302
                        spec_config.draft_model_config,
3303
3304
3305
3306
3307
3308
                        global_expert_load,
                        old_global_expert_indices,
                        rank_mapping,
                    )
                    eplb_models += 1

3309
            if self.use_aux_hidden_state_outputs:
3310
                if not supports_eagle3(self.get_model()):
3311
3312
                    raise RuntimeError(
                        "Model does not support EAGLE3 interface but "
3313
3314
                        "aux_hidden_state_outputs was requested"
                    )
3315
3316
3317
3318
3319
3320
3321
3322
3323
3324
3325
3326
3327

                # Try to get auxiliary layers from speculative config,
                # otherwise use model's default layers
                aux_layers = self._get_eagle3_aux_layers_from_config()
                if aux_layers:
                    logger.info(
                        "Using auxiliary layers from speculative config: %s",
                        aux_layers,
                    )
                else:
                    aux_layers = self.model.get_eagle3_aux_hidden_state_layers()

                self.model.set_aux_hidden_state_layers(aux_layers)
3328
            time_after_load = time.perf_counter()
3329
        self.model_memory_usage = m.consumed_memory
3330
        logger.info_once(
3331
            "Model loading took %.4f GiB memory and %.6f seconds",
3332
3333
            self.model_memory_usage / GiB_bytes,
            time_after_load - time_before_load,
3334
            scope="local",
3335
        )
3336
        prepare_communication_buffer_for_model(self.model)
3337
        mm_config = self.model_config.multimodal_config
3338
        self.is_multimodal_pruning_enabled = (
3339
            supports_multimodal_pruning(self.get_model())
3340
3341
            and mm_config is not None
            and mm_config.is_multimodal_pruning_enabled()
3342
        )
3343

3344
        if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb:
3345
3346
3347
3348
3349
3350
3351
3352
3353
3354
3355
            logger.info_once("EPLB is enabled for model %s.", self.model_config.model)
            global_expert_load = (
                global_expert_loads[eplb_models] if global_expert_loads else None
            )
            old_global_expert_indices = (
                old_global_expert_indices_per_model[eplb_models]
                if old_global_expert_indices_per_model
                else None
            )
            assert self.eplb_state is not None
            self.eplb_state.add_model(
3356
                self.model,
3357
                self.model_config,
3358
3359
3360
                global_expert_load,
                old_global_expert_indices,
                rank_mapping,
3361
3362
            )

3363
        if (
3364
3365
            self.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
3366
            and supports_dynamo()
3367
        ):
3368
            backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
3369
            compilation_counter.stock_torch_compile_count += 1
3370
            self.model.compile(fullgraph=True, backend=backend)
3371
            return
3372
        # for other compilation modes, cudagraph behavior is controlled by
3373
3374
3375
        # CudagraphWraper and CudagraphDispatcher of vllm.

        # wrap the model with full cudagraph wrapper if needed.
3376
3377
3378
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        if cudagraph_mode.has_full_cudagraphs() and not self.parallel_config.enable_dbo:
3379
3380
3381
            self.model = CUDAGraphWrapper(
                self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
            )
3382
        elif self.parallel_config.enable_dbo:
3383
            if cudagraph_mode.has_full_cudagraphs():
3384
3385
3386
                self.model = UBatchWrapper(
                    self.model, self.vllm_config, CUDAGraphMode.FULL, self.device
                )
3387
            else:
3388
3389
3390
                self.model = UBatchWrapper(
                    self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
                )
3391

3392
    def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
3393
3394
3395
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412
3413
3414
3415
        """Extract Eagle3 auxiliary layer indices from speculative config.

        These indices specify which hidden states from the base model should
        be used as auxiliary inputs for the Eagle3 drafter model during
        speculative decoding.

        Returns:
            Tuple of layer indices if found in draft model config,
            None otherwise.
        """
        if not (self.speculative_config and self.speculative_config.draft_model_config):
            return None

        hf_config = self.speculative_config.draft_model_config.hf_config
        if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
            return None

        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
        if layer_ids and isinstance(layer_ids, (list, tuple)):
            return tuple(layer_ids)

        return None

3416
    def reload_weights(self) -> None:
        """Reload the weights of the already-loaded model in place.

        A fresh model loader is built from ``self.load_config`` and asked to
        load weights into the existing model object returned by
        ``get_model()``; the model object itself is not re-created.

        Raises:
            AssertionError: if no model has been loaded yet.
        """
        loaded_model = getattr(self, "model", None)
        assert loaded_model is not None, (
            "Cannot reload weights before model is loaded."
        )
        loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        loader.load_weights(self.get_model(), model_config=self.model_config)
3423

3424
3425
3426
3427
3428
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the current model with Tensorizer.

        Delegates to ``TensorizerLoader.save_model`` with the model returned by
        ``get_model()``, the given ``tensorizer_config`` and this runner's
        ``model_config``.
        """
        target_model = self.get_model()
        TensorizerLoader.save_model(
            target_model,
            tensorizer_config=tensorizer_config,
            model_config=self.model_config,
        )

3434
3435
3436
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: dict[str, int],
    ) -> dict[str, LogprobsTensors | None]:
        """Compute prompt logprobs for the requests that asked for them.

        With chunked prefill, prompt logprobs are produced over several steps.
        Each call fills the slice of a per-request CPU ``LogprobsTensors``
        (kept in ``input_batch.in_progress_prompt_logprobs_cpu``) that
        corresponds to the tokens scheduled in this step. A request's tensors
        are only returned once its final chunk has been processed.

        Args:
            hidden_states: Hidden states produced by the model for this step;
                each request's rows start at ``self.query_start_loc`` for its
                batch index.
            num_scheduled_tokens: Number of tokens scheduled this step, keyed
                by request id.

        Returns:
            A mapping from request id to its prompt logprobs for the requests
            whose prompt finished in this step. It is empty when no request
            needs prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            if request.prompt_token_ids is None:
                # Prompt logprobs is incompatible with prompt embeddings
                continue

            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1
                )
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc.np[req_idx].item()
            prompt_hidden_states = hidden_states[offset : offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Slice on the host first so only the
            # ids needed for this chunk are turned into a tensor and moved to
            # the device (instead of the whole prompt on every step).
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok : start_tok + num_logits]
            ).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids
            )

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True
            )
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            self._sync_device()

        return prompt_logprobs_dict

3536
3537
    def _get_nans_in_logits(
        self,
3538
        logits: torch.Tensor | None,
3539
3540
3541
3542
3543
3544
3545
3546
3547
3548
3549
    ) -> dict[str, int]:
        try:
            if logits is None:
                return {req_id: 0 for req_id in self.input_batch.req_ids}

            num_nans_in_logits = {}
            num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
            for req_id in self.input_batch.req_ids:
                req_index = self.input_batch.req_id_to_index[req_id]
                num_nans_in_logits[req_id] = (
                    int(num_nans_for_index[req_index])
3550
3551
3552
                    if num_nans_for_index is not None and req_index < logits.shape[0]
                    else 0
                )
3553
3554
3555
3556
            return num_nans_in_logits
        except IndexError:
            return {}

3557
3558
3559
3560
3561
3562
    @contextmanager
    def maybe_randomize_inputs(self, input_ids: torch.Tensor | None):
        """Temporarily fill ``input_ids`` with random token ids.

        Randomization happens only when ``VLLM_RANDOMIZE_DP_DUMMY_INPUTS`` is
        set and ``data_parallel_size > 1``. This is to help balance
        expert-selection
         - during profile_run
         - during DP rank dummy run

        When active, ``input_ids`` is overwritten in place with ids drawn from
        ``[0, vocab_size)`` before the ``with`` body runs. It is zero-filled
        after the body exits normally.

        Args:
            input_ids: The dummy input-id buffer to randomize in place. It may
                be ``None`` (e.g. when the dummy run feeds ``inputs_embeds``
                instead of token ids). In that case nothing is randomized;
                previously this raised an AttributeError.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
        if not randomize_inputs or input_ids is None:
            yield
            return

        logger.debug_once("Randomizing dummy data for DP Rank")
        # Draw exactly as many ids as the buffer holds. There is no need to
        # generate ids for the whole persistent buffer and then slice.
        random_ids = torch.randint_like(
            input_ids,
            low=0,
            high=self.model_config.get_vocab_size(),
            dtype=input_ids.dtype,
        )
        input_ids.copy_(random_ids, non_blocking=True)
        yield
        input_ids.fill_(0)

3586
3587
3588
3589
3590
3591
    def _get_mm_dummy_batch(
        self,
        modality: str,
        max_items_per_batch: int,
    ) -> BatchedTensorInputs:
        """Build a batched dummy multimodal input for profiling/precompiling.

        One dummy item of ``modality`` is requested from the multimodal
        registry (with ``seq_len=self.max_model_len``) and repeated
        ``max_items_per_batch`` times. The kwargs of the first group yielded
        by ``group_mm_kwargs_by_modality`` are returned.
        """
        assert self.mm_budget is not None

        decoder_dummy = self.mm_registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_model_len,
            mm_counts={modality: 1},
            cache=self.mm_budget.cache,
        )

        # Result in the maximum GPU consumption of the model
        single_item = decoder_dummy.multi_modal_data[modality][0]
        batch_items = [single_item] * max_items_per_batch

        mm_model = cast(SupportsMultiModal, self.model)
        grouped = group_mm_kwargs_by_modality(
            batch_items,
            device=self.device,
            pin_memory=self.pin_memory,
            merge_by_field_config=mm_model.merge_by_field_config,
            multimodal_cpu_fields=mm_model.multimodal_cpu_fields,
        )
        _, _, first_group_kwargs = next(iter(grouped))
        return first_group_kwargs
3617

3618
3619
3620
3621
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        cudagraph_runtime_mode: CUDAGraphMode | None = None,
        force_attention: bool = False,
        uniform_decode: bool = False,
        allow_microbatching: bool = True,
        skip_eplb: bool = False,
        is_profile: bool = False,
        create_mixed_batch: bool = False,
        remove_lora: bool = True,
        activate_lora: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a dummy forward pass to warm up/profile run or capture the
        CUDA graph for the model.

        Args:
            num_tokens: Number of tokens to run the dummy forward pass.
            cudagraph_runtime_mode: used to control the behavior.
                - if not set will determine the cudagraph mode based on using
                    the self.cudagraph_dispatcher.
                - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
                - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
                - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
                    needed.
            force_attention: If True, always create attention metadata. Used to
                warm up attention backend when mode is NONE.
            uniform_decode: If True, the batch is a uniform decode batch.
            allow_microbatching: Forwarded to ``coordinate_batch_across_dp``,
                which decides whether the dummy batch is split into
                micro-batches (``ubatch_slices``).
            skip_eplb: If True, skip EPLB state update.
            is_profile: If True, this is a profile run.
            create_mixed_batch: If True, create a mixed batch with both decode
                (1 token) and prefill (multiple tokens) requests.
            remove_lora: If False, dummy LoRAs are not destroyed after the run
            activate_lora: If False, dummy_run is performed without LoRAs.

        Returns:
            A tuple ``(hidden_states, last_token_hidden_states)``. The first
            element is the model's hidden states for the dummy batch. The
            second holds the rows at the last scheduled token of each dummy
            request.
        """
        assert (
            cudagraph_runtime_mode is None
            or cudagraph_runtime_mode.valid_runtime_modes()
        )

        # If cudagraph_mode.decode_mode() == FULL and
        # cudagraph_mode.separate_routine(). This means that we are using
        # different graphs and/or modes for mixed prefill-decode batches vs.
        # uniform decode batches. A uniform decode batch means that all
        # requests have identical query length, except a potential virtual
        # request (shorter) in the batch account for padding.
        # Uniform decode batch could either be common pure decode, where
        # max_query_len == 1, or speculative decode, where
        # max_query_len == 1 + num_spec_decode_tokens.

        # When setting max_query_len = 1, we switch to and capture the optimized
        # routine of FA2 for pure decode, i.e., Flashdecode + an optimization
        # for GQA/MQA.
        max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        if create_mixed_batch:
            assert not uniform_decode
            # Create mixed batch:
            # first half decode tokens, second half one prefill
            num_decode_tokens = min(max_num_reqs - 1, num_tokens // 2)
            num_prefill_tokens = num_tokens - num_decode_tokens
            num_reqs = num_decode_tokens + 1

            # Create decode requests (1 token each) followed by prefill request
            num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens]
            # Note: Overriding max_query_len to be the prefill tokens
            max_query_len = num_prefill_tokens
        elif uniform_decode:
            # create_mixed_batch is necessarily False on this branch.
            num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len))
            num_scheduled_tokens_list = [max_query_len] * num_reqs
            if num_tokens % max_query_len != 0:
                num_scheduled_tokens_list[-1] = num_tokens % max_query_len
        else:
            num_reqs = min(num_tokens, max_num_reqs)
            min_tokens_per_req = num_tokens // num_reqs
            num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
            num_scheduled_tokens_list[-1] += num_tokens % num_reqs

        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
        total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
        num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)

        # Disable DP padding when running eager
        allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE

        # We currently only microbatch if the number of tokens is
        # over a certain threshold.
        ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
            num_tokens_unpadded=total_num_scheduled_tokens,
            parallel_config=self.vllm_config.parallel_config,
            allow_microbatching=allow_microbatching,
            allow_dp_padding=allow_dp_padding,
            num_tokens_padded=total_num_scheduled_tokens,
            uniform_decode=uniform_decode,
            num_scheduled_tokens_per_request=num_scheduled_tokens,
        )
        num_tokens_after_padding = num_tokens
        if num_tokens_across_dp is not None:
            dp_rank = self.parallel_config.data_parallel_rank
            num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])

        attn_metadata: PerLayerAttnMetadata | None = None

        # If force_attention is True, we always capture attention. Otherwise,
        # it only happens for cudagraph_runtime_mode=FULL.
        if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
            if create_mixed_batch:
                # In the mixed batch mode (used for FI warmup), we use
                # shorter sequence lengths to run faster.
                # TODO(luka) better system for describing dummy batches
                seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
            else:
                seq_lens = max_query_len  # type: ignore[assignment]
            self.seq_lens.np[:num_reqs] = seq_lens
            self.seq_lens.np[num_reqs:] = 0
            self.seq_lens.copy_to_gpu()

            cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
            self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
            self.query_start_loc.copy_to_gpu()

            attn_metadata, _ = self._build_attention_metadata(
                total_num_scheduled_tokens=num_tokens,
                max_num_scheduled_tokens=max_query_len,
                num_reqs=num_reqs,
                ubatch_slices=ubatch_slices,
                for_cudagraph_capture=True,
            )

        with self.maybe_dummy_run_with_lora(
            self.lora_config,
            num_scheduled_tokens,
            num_sampled_tokens,
            activate_lora,
            remove_lora,
        ):
            # Make sure padding doesn't exceed max_num_tokens
            assert num_tokens_after_padding <= self.max_num_tokens
            model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
            if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
                input_ids = None
                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
                model_kwargs = {
                    **model_kwargs,
                    **self._dummy_mm_kwargs(num_reqs),
                }
            elif self.enable_prompt_embeds:
                # model_kwargs from _init_model_kwargs above is reused as is.
                input_ids = None
                inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
            else:
                input_ids = self.input_ids.gpu[:num_tokens_after_padding]
                inputs_embeds = None

            if self.uses_mrope:
                positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
            else:
                positions = self.positions.gpu[:num_tokens_after_padding]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device,
                        )
                    )

                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                    num_tokens_after_padding, None, False
                )

            # filter out the valid batch descriptor
            _cg_mode, batch_descriptor = (
                self.cudagraph_dispatcher.dispatch(
                    BatchDescriptor(
                        num_tokens=num_tokens_after_padding,
                        uniform_decode=uniform_decode,
                        has_lora=activate_lora and self.lora_config is not None,
                    )
                )
                if not is_profile
                else (CUDAGraphMode.NONE, None)
            )
            if cudagraph_runtime_mode is not None:
                # we allow forcing NONE when the dispatcher disagrees to support
                # warm ups for cudagraph capture
                assert (
                    cudagraph_runtime_mode == CUDAGraphMode.NONE
                    or cudagraph_runtime_mode == _cg_mode
                ), (
                    f"Cudagraph runtime mode mismatch at dummy_run. "
                    f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
                )
            else:
                cudagraph_runtime_mode = _cg_mode

            if ubatch_slices is not None:
                # Adjust values to reflect a single ubatch.
                # TODO(sage,lucas): this is cruft that should be addressed in
                #  the padding refactor.
                num_tokens_after_padding = ubatch_slices[0].num_tokens
                if num_tokens_across_dp is not None:
                    num_tokens_across_dp[:] = num_tokens_after_padding

            with (
                self.maybe_randomize_inputs(input_ids),
                set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens_after_padding,
                    num_tokens_across_dp=num_tokens_across_dp,
                    cudagraph_runtime_mode=cudagraph_runtime_mode,
                    batch_descriptor=batch_descriptor,
                    ubatch_slices=ubatch_slices,
                ),
            ):
                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )

            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs

            if self.speculative_config and self.speculative_config.use_eagle():
                assert isinstance(self.drafter, EagleProposer)
                use_cudagraphs = (
                    cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
                    and not self.speculative_config.enforce_eager
                )

                # Note(gnovack) - We need to disable cudagraphs for one of the two
                # lora cases when cudagraph_specialize_lora is enabled. This is a
                # short term mitigation for issue mentioned in
                # https://github.com/vllm-project/vllm/issues/28334
                if self.compilation_config.cudagraph_specialize_lora and activate_lora:
                    use_cudagraphs = False

                self.drafter.dummy_run(
                    num_tokens,
                    use_cudagraphs=use_cudagraphs,
                )

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        # Index of the last scheduled token of each dummy request.
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        logit_indices_device = torch.from_numpy(logit_indices).to(
            self.device, non_blocking=True
        )
        return hidden_states, hidden_states[logit_indices_device]
3895
3896
3897
3898
3899
3900

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler on dummy data.

        If a speculative config is set, the rejection sampler is warmed up as
        well.

        Args:
            hidden_states: Tensor whose shape and dtype are used for the dummy
                input; its values are replaced with random data.

        Returns:
            The value returned by ``self.sampler`` for the dummy batch.

        Raises:
            RuntimeError: if sampling fails with an "out of memory" error. The
                message suggests lowering ``max_num_seqs`` or
                ``gpu_memory_utilization``. Other RuntimeErrors are re-raised
                unchanged.
        """
        # The dummy hidden states may contain special values,
        # like `inf` or `nan`.
        # To avoid breaking the sampler, we use a random tensor here instead.
        hidden_states = torch.rand_like(hidden_states)

        logits = self.model.compute_logits(hidden_states)
        num_reqs = logits.size(0)

        def dummy_tensors(value):
            # One entry per dummy request, on the runner's device.
            return torch.full((num_reqs,), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            spec_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessors(),
        )
        try:
            sampler_output = self.sampler(
                logits=logits, sampling_metadata=dummy_metadata
            )
        except RuntimeError as e:
            if "out of memory" in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up sampler with "
                    f"{num_reqs} dummy requests. Please try lowering "
                    "`max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine."
                ) from e
            raise
        if self.speculative_config:
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device
            )

            num_tokens = sum(len(ids) for ids in draft_token_ids)
            # No draft probabilities are supplied for the warm-up.
            draft_probs = None
            # One row per draft token plus one extra row per request
            # (presumably the bonus token; confirm against rejection sampler).
            logits = torch.randn(
                num_tokens + num_reqs,
                logits.shape[-1],
                device=self.device,
                dtype=logits.dtype,
            )
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                logits,
                dummy_metadata,
            )
        return sampler_output
3968

3969
    def _dummy_pooler_run_task(
        self,
        hidden_states: torch.Tensor,
        task: PoolingTask,
    ) -> PoolerOutput:
        """Run the pooler once on a dummy batch for a single pooling task.

        The rows of ``hidden_states`` are spread over
        ``min(num_tokens, scheduler_config.max_num_seqs)`` dummy requests.
        Every request gets ``num_tokens // num_reqs`` tokens, and the last one
        also takes the remainder.

        Args:
            hidden_states: Dummy hidden states. Only ``shape[0]`` (the token
                count) is used to lay out the requests.
            task: The pooling task to warm up.

        Returns:
            The result of ``model.pooler`` for the dummy batch.

        Raises:
            RuntimeError: Re-raised with a configuration hint when the pooler
                fails with an error whose message contains "out of memory".
                Other RuntimeErrors propagate unchanged.
        """
        total_tokens = hidden_states.shape[0]
        num_reqs = min(total_tokens, self.scheduler_config.max_num_seqs)
        per_req, leftover = divmod(total_tokens, num_reqs)

        # Equal share for every request; the last request absorbs the leftover.
        tokens_per_req = [per_req] * num_reqs
        tokens_per_req[-1] += leftover
        assert sum(tokens_per_req) == total_tokens
        assert len(tokens_per_req) == num_reqs

        prompt_lens = torch.tensor(tokens_per_req, device="cpu")
        # NOTE(review): token ids are laid out as (num_reqs, per_req), so the
        # last request's leftover tokens have no column here. Confirm that no
        # pooler indexes prompt_token_ids up to prompt_lens.
        token_ids = torch.zeros(
            (num_reqs, per_req), dtype=torch.int32, device=self.device
        )

        model = cast(VllmModelForPooling, self.get_model())
        pooling_params = PoolingParams(task=task)
        pooling_params.verify(task=task, model_config=self.model_config)
        model.pooler.get_pooling_updates(task).apply(pooling_params)

        metadata = PoolingMetadata(
            prompt_lens=prompt_lens,
            prompt_token_ids=token_ids,
            pooling_params=[pooling_params] * num_reqs,
        )
        metadata.build_pooling_cursor(tokens_per_req, device=hidden_states.device)

        try:
            return model.pooler(
                hidden_states=hidden_states, pooling_metadata=metadata
            )
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise e
            raise RuntimeError(
                "CUDA out of memory occurred when warming up pooler "
                f"({task=}) with {num_reqs} dummy requests. Please try "
                "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine."
            ) from e

    @torch.inference_mode()
    def _dummy_pooler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> PoolerOutput:
        """Warm up the pooler for every supported pooling task.

        Each supported task is run once on the full dummy batch, so an OOM in
        any of them surfaces now. The task whose output took the most bytes is
        then run one more time, and that final output is returned.

        Raises:
            RuntimeError: If the model reports no supported pooling tasks. The
                message depends on whether chunked prefill is enabled.
        """
        tasks = self.get_supported_pooling_tasks()

        if not tasks:
            if self.scheduler_config.enable_chunked_prefill:
                reason = (
                    "any pooling tasks with chunked prefill enabled. "
                    "Please add --no-enable-chunked-prefill to your "
                    "config or CLI args. See "
                )
            else:
                reason = "any pooling tasks. See "
            raise RuntimeError(
                f"Model {self.model_config.model} does not support "
                + reason
                + (
                    "https://docs.vllm.ai/en/latest/models/pooling_models.html "
                    "to learn more."
                )
            )

        bytes_by_task: dict[PoolingTask, float] = {}
        for task in tasks:
            # A full batch per task makes sure none of them OOMs.
            result = self._dummy_pooler_run_task(hidden_states, task)
            bytes_by_task[task] = sum(o.nbytes for o in result)
            del result  # drop the reference so the memory can be reclaimed

        # max() keeps the first task on ties, in insertion order.
        largest_task = max(bytes_by_task, key=bytes_by_task.__getitem__)
        return self._dummy_pooler_run_task(hidden_states, largest_task)

    def profile_run(self) -> None:
        """Run worst-case dummy passes so that peak memory can be profiled.

        Steps, in order:

        1. If the model supports multimodal inputs,
           ``multimodal_config.skip_mm_profiling`` is not set, and the encoder
           budget is positive: run the multimodal encoder on a dummy batch of
           the modality with the most tokens, and keep its (possibly
           zero-padded) outputs in ``self.encoder_cache["tmp"]``.
        2. Run ``self._dummy_run`` with ``self.max_num_tokens`` tokens and
           ``is_profile=True``.
        3. On the last pipeline-parallel rank only, run the dummy pooler
           (pooling models) or the dummy sampler (all other models).
        4. Synchronize the device, clear the encoder cache, and run the
           garbage collector.
        """
        # All work stays in this one scope on purpose: the encoder inputs and
        # outputs remain referenced while `_dummy_run` executes, exactly as
        # they would if profiling measured them together.
        if self.supports_mm_inputs:
            mm_config = self.model_config.multimodal_config
            skip_mm_profiling = (
                mm_config is not None and mm_config.skip_mm_profiling
            )
            if skip_mm_profiling:
                logger.info(
                    "Skipping memory profiling for multimodal encoder and "
                    "encoder cache."
                )
            else:
                budget = self.mm_budget
                assert budget is not None

                encoder_budget = budget.get_encoder_budget()
                if encoder_budget > 0:
                    # NOTE: Currently model is profiled with a single non-text
                    # modality with the max possible input tokens even when
                    # it supports multiple.
                    modality = budget.get_modality_with_max_tokens()
                    num_items = budget.max_items_per_batch_by_modality[modality]

                    logger.info(
                        "Encoder cache will be initialized with a budget of "
                        "%s tokens, and profiled with %s %s items of the "
                        "maximum feature size.",
                        encoder_budget,
                        num_items,
                        modality,
                    )

                    # Build a dummy multimodal batch and run the encoder on it.
                    mm_inputs = self._get_mm_dummy_batch(modality, num_items)
                    encoder_outputs = self.model.embed_multimodal(**mm_inputs)
                    sanity_check_mm_encoder_outputs(
                        encoder_outputs,
                        expected_num_items=num_items,
                    )

                    # When the encoder output is shorter than the encoder
                    # budget, the cache must hold the embeddings that encoder
                    # outputs are scattered onto. Each output is copied into
                    # a zero tensor of shape (encoder_budget, hidden_size).
                    first_shape = encoder_outputs[0].shape
                    if first_shape[0] < encoder_budget:
                        padded_outputs = []
                        for enc_out in encoder_outputs:
                            padded = enc_out.new_zeros(
                                (encoder_budget, first_shape[-1])
                            )
                            padded[: enc_out.shape[0]].copy_(enc_out)
                            padded_outputs.append(padded)
                        encoder_outputs = padded_outputs

                    # Cache the dummy encoder outputs.
                    self.encoder_cache["tmp"] = dict(enumerate(encoder_outputs))

        # Add `is_profile` here to pre-allocate communication buffers
        hidden_states, last_hidden_states = self._dummy_run(
            self.max_num_tokens, is_profile=True
        )
        output = None
        if get_pp_group().is_last_rank:
            output = (
                self._dummy_pooler_run(hidden_states)
                if self.is_pooling_model
                else self._dummy_sampler_run(last_hidden_states)
            )
        self._sync_device()
        del hidden_states, output
        self.encoder_cache.clear()
        gc.collect()

    def capture_model(self) -> int:
        """Capture CUDA graphs for the configured batch sizes.

        When the mixed mode of ``cudagraph_mode`` is not NONE, graphs are
        captured for every size in ``self.cudagraph_batch_sizes``. When the
        decode mode is FULL and ``cudagraph_mode.separate_routine()`` is true,
        FULL graphs are also captured for uniform-decode batches whose size
        lies in ``[uniform_decode_query_len,
        max_num_seqs * uniform_decode_query_len]``. Each size is captured once
        per LoRA case: ``[False]`` without LoRA; ``[True, False]`` or
        ``[True]`` with LoRA, depending on ``cudagraph_specialize_lora``.

        Global cudagraph capturing is enabled only for the duration of the
        capture. It is disabled again even if capturing raises.

        Returns:
            The drop in free device memory (bytes, as reported by
            ``torch.cuda.mem_get_info``) caused by capturing, or 0 when
            ``cudagraph_mode`` is NONE and nothing is captured.
        """
        if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        compilation_counter.num_gpu_runner_capture_triggers += 1

        start_time = time.perf_counter()

        @contextmanager
        def freeze_gc():
            # Optimize garbage collection during CUDA graph capture.
            # Clean up, then freeze all remaining objects from being included
            # in future collections (skipped when VLLM_ENABLE_CUDAGRAPH_GC).
            gc.collect()
            should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
            if should_freeze:
                gc.freeze()
            try:
                yield
            finally:
                if should_freeze:
                    gc.unfreeze()
                    gc.collect()

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        set_cudagraph_capturing_enabled(True)
        try:
            with freeze_gc(), graph_capture(device=self.device):
                start_free_gpu_memory = torch.cuda.mem_get_info()[0]
                cudagraph_mode = self.compilation_config.cudagraph_mode
                assert cudagraph_mode is not None

                if self.lora_config:
                    if self.compilation_config.cudagraph_specialize_lora:
                        lora_cases = [True, False]
                    else:
                        lora_cases = [True]
                else:
                    lora_cases = [False]

                if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
                    cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
                    # make sure we capture the largest batch size first
                    compilation_cases = list(
                        product(reversed(self.cudagraph_batch_sizes), lora_cases)
                    )
                    self._capture_cudagraphs(
                        compilation_cases,
                        cudagraph_runtime_mode=cudagraph_runtime_mode,
                        uniform_decode=False,
                    )

                # Capture full cudagraph for uniform decode batches if we
                # don't already have full mixed prefill-decode cudagraphs.
                if (
                    cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
                    and cudagraph_mode.separate_routine()
                ):
                    max_num_tokens = (
                        self.scheduler_config.max_num_seqs
                        * self.uniform_decode_query_len
                    )
                    decode_cudagraph_batch_sizes = [
                        x
                        for x in self.cudagraph_batch_sizes
                        if max_num_tokens >= x >= self.uniform_decode_query_len
                    ]
                    compilation_cases_decode = list(
                        product(reversed(decode_cudagraph_batch_sizes), lora_cases)
                    )
                    self._capture_cudagraphs(
                        compilation_cases=compilation_cases_decode,
                        cudagraph_runtime_mode=CUDAGraphMode.FULL,
                        uniform_decode=True,
                    )

                torch.cuda.synchronize()
                end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        finally:
            # Disable cudagraph capturing globally, so any unexpected cudagraph
            # capturing will be detected and raise an error after here. This
            # also runs when the capture above failed, so the global flag is
            # never left enabled.
            # Note: We don't put it into graph_capture context manager because
            # we may do lazy capturing in future that still allows capturing
            # after here.
            set_cudagraph_capturing_enabled(False)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info_once(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
            scope="local",
        )
        return cuda_graph_size

    def _capture_cudagraphs(
        self,
        compilation_cases: list[tuple[int, bool]],
        cudagraph_runtime_mode: CUDAGraphMode,
        uniform_decode: bool,
    ):
        """Warm up and capture a CUDA graph for each compilation case.

        For every ``(num_tokens, activate_lora)`` case, in the given order,
        this does ``cudagraph_num_of_warmups`` dummy runs with
        ``CUDAGraphMode.NONE`` and then one dummy run with
        ``cudagraph_runtime_mode``. All runs pass ``skip_eplb=True``, so no
        dummy EPLB metrics are recorded, and ``remove_lora=False``.
        ``maybe_remove_all_loras`` is called once after the last case.

        Args:
            compilation_cases: ``(num_tokens, activate_lora)`` pairs.
            cudagraph_runtime_mode: Runtime mode for the capturing run. Must
                not be ``CUDAGraphMode.NONE``.
            uniform_decode: Whether the cases are uniform-decode batches.
        """
        assert (
            cudagraph_runtime_mode != CUDAGraphMode.NONE
            and cudagraph_runtime_mode.valid_runtime_modes()
        ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

        # Only rank 0 should print progress bar during capture
        if is_global_first_rank():
            phase = "decode" if uniform_decode else "mixed prefill-decode"
            compilation_cases = tqdm(
                compilation_cases,
                disable=not self.load_config.use_tqdm_on_load,
                desc=f"Capturing CUDA graphs ({phase}, {cudagraph_runtime_mode.name})",
            )

        # Warm-up always uses CUDAGraphMode.NONE. Whether attention runs during
        # warm-up is a separate choice: FULL graphs capture attention, so their
        # warm-up forces it on; PIECEWISE graphs do not capture attention.
        force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
        num_warmups = self.compilation_config.cudagraph_num_of_warmups

        for num_tokens, activate_lora in compilation_cases:
            # Microbatched (ubatched) graphs are captured only for FULL,
            # uniform-decode graphs when DBO is enabled and the token count
            # passes the ubatch threshold check. Everything else gets a
            # non-ubatched graph.
            allow_microbatching = (
                self.parallel_config.enable_dbo
                and cudagraph_runtime_mode == CUDAGraphMode.FULL
                and uniform_decode
                and check_ubatch_thresholds(
                    config=self.vllm_config.parallel_config,
                    num_tokens=num_tokens,
                    uniform_decode=uniform_decode,
                )
            )

            for _ in range(num_warmups):
                self._dummy_run(
                    num_tokens,
                    cudagraph_runtime_mode=CUDAGraphMode.NONE,
                    force_attention=force_attention,
                    uniform_decode=uniform_decode,
                    allow_microbatching=allow_microbatching,
                    skip_eplb=True,
                    remove_lora=False,
                    activate_lora=activate_lora,
                )

            self._dummy_run(
                num_tokens,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                uniform_decode=uniform_decode,
                allow_microbatching=allow_microbatching,
                skip_eplb=True,
                remove_lora=False,
                activate_lora=activate_lora,
            )

        self.maybe_remove_all_loras(self.lora_config)

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and populate ``self.attn_groups``.

        Within every KV cache group, layers are bucketed by (attention backend,
        per-layer KV cache spec), and each bucket becomes one
        ``AttentionGroup``. Before the groups are created, the set of backends
        used by each KV cache group is handed to
        ``_check_and_update_cudagraph_mode`` to resolve the CUDA graph mode.
        """
        assert len(self.attn_groups) == 0, "Attention backends are already initialized"

        class AttentionGroupKey(NamedTuple):
            attn_backend: type[AttentionBackend]
            kv_cache_spec: KVCacheSpec

        def get_attn_backends_for_group(
            kv_cache_group_spec: KVCacheGroupSpec,
        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
            layers = get_layers_from_vllm_config(
                self.vllm_config,
                cast(type[Any], AttentionLayerBase),
                kv_cache_group_spec.layer_names,
            )
            group_kv_cache_spec = kv_cache_group_spec.kv_cache_spec

            # Dedupe based on full class name; this is a bit safer than
            # using the class itself as the key because when we create dynamic
            # attention backend subclasses (e.g. ChunkedLocalAttention) unless
            # they are cached correctly, there will be different objects per
            # layer. The last backend object seen for a key is the one kept.
            key_to_group: dict[tuple[str, KVCacheSpec], AttentionGroupKey] = {}
            key_to_layers: dict[tuple[str, KVCacheSpec], list[str]] = defaultdict(
                list
            )
            for layer_name in kv_cache_group_spec.layer_names:
                backend = layers[layer_name].get_attn_backend()
                if layer_name in self.kv_sharing_fast_prefill_eligible_layers:
                    backend = create_fast_prefill_custom_backend(
                        "FastPrefill",
                        backend,  # type: ignore[arg-type]
                    )

                cls_name = backend.full_cls_name()
                layer_spec = group_kv_cache_spec
                if isinstance(layer_spec, UniformTypeKVCacheSpecs):
                    layer_spec = layer_spec.kv_cache_specs[layer_name]

                dedupe_key = (cls_name, layer_spec)
                key_to_group[dedupe_key] = AttentionGroupKey(backend, layer_spec)
                key_to_layers[dedupe_key].append(layer_name)

            backend_map = {
                key_to_group[k]: names for k, names in key_to_layers.items()
            }
            backend_set = {group.attn_backend for group in key_to_group.values()}
            return backend_map, backend_set

        def create_attn_groups(
            attn_backends_map: dict[AttentionGroupKey, list[str]],
            kv_cache_group_id: int,
        ) -> list[AttentionGroup]:
            return [
                AttentionGroup(backend, names, spec, kv_cache_group_id)
                for (backend, spec), names in attn_backends_map.items()
            ]

        per_group_results = [
            get_attn_backends_for_group(group_spec)
            for group_spec in kv_cache_config.kv_cache_groups
        ]
        attention_backend_maps = [result[0] for result in per_group_results]
        attention_backend_list = [result[1] for result in per_group_results]

        # Resolve cudagraph_mode before the metadata builders are initialized.
        self._check_and_update_cudagraph_mode(
            attention_backend_list, kv_cache_config.kv_cache_groups
        )

        for group_id, backend_map in enumerate(attention_backend_maps):
            self.attn_groups.append(create_attn_groups(backend_map, group_id))

    def initialize_metadata_builders(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> None:
        """
        Create the metadata builders for all KV cache groups and attn groups,
        then recompute the reorder batch threshold.

        ``kernel_block_sizes[i]`` is passed as the kernel block size for the
        attention groups of KV cache group ``i``. Groups past the end of that
        list get ``None``. Two builders per attention group are requested when
        ``parallel_config.enable_dbo`` is set, otherwise one.
        """
        num_builders = 2 if self.parallel_config.enable_dbo else 1
        num_block_sizes = len(kernel_block_sizes)

        for group_id in range(len(kv_cache_config.kv_cache_groups)):
            block_size = (
                kernel_block_sizes[group_id] if group_id < num_block_sizes else None
            )
            for attn_group in self.attn_groups[group_id]:
                attn_group.create_metadata_builders(
                    self.vllm_config,
                    self.device,
                    block_size,
                    num_metadata_builders=num_builders,
                )

        # Calculate reorder batch threshold (if needed)
        # Note (tdoublep): do this *after* constructing builders,
        # because some of them change the threshold at init time.
        self.calculate_reorder_batch_threshold()

    def _check_and_update_cudagraph_mode(
        self,
        attention_backends: list[set[type[AttentionBackend]]],
        kv_cache_groups: list[KVCacheGroupSpec],
    ) -> None:
        """
        Resolve the cudagraph_mode when there are multiple attention
        groups with potential conflicting CUDA graph support.

        The backend with the lowest ``AttentionCGSupport`` value across all
        groups decides whether, and how far, the requested mode is downgraded.
        Every downgrade is logged as a warning and written back to
        ``compilation_config.cudagraph_mode``. Then initialize the
        cudagraph_dispatcher based on the resolved cudagraph_mode.

        Args:
            attention_backends: For each KV cache group, the set of attention
                backends used by that group.
            kv_cache_groups: The KV cache groups, in the same order as
                ``attention_backends``.

        Raises:
            ValueError: If full CUDA graphs are requested while some backend's
                support is ``AttentionCGSupport.NEVER`` and no downgrade
                applies.
        """
        # Find the weakest CUDA graph support among all backends in use.
        min_cg_support = AttentionCGSupport.ALWAYS
        min_cg_backend_name = None
        for attn_backend_set, kv_cache_group in zip(
            attention_backends, kv_cache_groups
        ):
            for attn_backend in attn_backend_set:
                builder_cls = attn_backend.get_builder_cls()

                cg_support = builder_cls.get_cudagraph_support(
                    self.vllm_config, kv_cache_group.kv_cache_spec
                )
                if cg_support.value < min_cg_support.value:
                    min_cg_support = cg_support
                    min_cg_backend_name = attn_backend.__name__

        def unsupported_msg(mode: CUDAGraphMode) -> str:
            # Common prefix of the "mode not supported by backend" messages.
            return (
                f"CUDAGraphMode.{mode.name} is not supported "
                f"with {min_cg_backend_name} backend (support: "
                f"{min_cg_support})"
            )

        try_piecewise_hint = (
            "; please try cudagraph_mode=PIECEWISE, and "
            "make sure compilation mode is VLLM_COMPILE"
        )

        # Flexible resolve the cudagraph mode
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        # check cudagraph for mixed batch is supported
        if (
            cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL
            and min_cg_support != AttentionCGSupport.ALWAYS
        ):
            msg = unsupported_msg(cudagraph_mode)
            if min_cg_support == AttentionCGSupport.NEVER:
                # if not supported any full cudagraphs, just raise it.
                raise ValueError(msg + try_piecewise_hint)

            # attempt to resolve the full cudagraph related mode
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_AND_PIECEWISE
                )
            else:
                msg += "; setting cudagraph_mode=FULL_DECODE_ONLY"
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.FULL_DECODE_ONLY
                )
            logger.warning(msg)

        # check that if we are doing decode full-cudagraphs it is supported
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and min_cg_support == AttentionCGSupport.NEVER
        ):
            msg = unsupported_msg(cudagraph_mode)
            if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and (
                self.compilation_config.splitting_ops_contain_attention()
                or self.compilation_config.use_inductor_graph_partition
            ):
                msg += (
                    "; setting cudagraph_mode=PIECEWISE because "
                    "attention is compiled piecewise"
                )
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.PIECEWISE
                )
            else:
                msg += (
                    "; setting cudagraph_mode=NONE because "
                    "attention is not compiled piecewise"
                )
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.NONE
                )
            logger.warning(msg)

        # check that if we are doing spec-decode + decode full-cudagraphs it is
        # supported
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and self.uniform_decode_query_len > 1
            and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value
        ):
            msg = (
                f"CUDAGraphMode.{cudagraph_mode.name} is not supported"
                f" with spec-decode for attention backend "
                f"{min_cg_backend_name} (support: {min_cg_support})"
            )
            if self.compilation_config.splitting_ops_contain_attention():
                msg += "; setting cudagraph_mode=PIECEWISE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.PIECEWISE
                )
            else:
                msg += "; setting cudagraph_mode=NONE"
                cudagraph_mode = self.compilation_config.cudagraph_mode = (
                    CUDAGraphMode.NONE
                )
            logger.warning(msg)

        # double check that we can support full cudagraph if they are requested
        # even after automatic downgrades
        if (
            cudagraph_mode.has_full_cudagraphs()
            and min_cg_support == AttentionCGSupport.NEVER
        ):
            raise ValueError(unsupported_msg(cudagraph_mode) + try_piecewise_hint)

        # if we have dedicated decode cudagraphs, and spec-decode is enabled,
        # we need to adjust the cudagraph sizes to be a multiple of the uniform
        # decode query length to avoid: https://github.com/vllm-project/vllm/issues/28207
        # temp-fix: https://github.com/vllm-project/vllm/issues/28207#issuecomment-3504004536
        # Will be removed in the near future when we have separate cudagraph capture
        # sizes for decode and mixed prefill-decode.
        if (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
            and self.uniform_decode_query_len > 1
        ):
            self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(
                self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size
            )
            capture_sizes = self.compilation_config.cudagraph_capture_sizes
            self.cudagraph_batch_sizes = (
                capture_sizes if capture_sizes is not None else []
            )

        # Trigger cudagraph dispatching keys initialization after
        # resolved cudagraph mode.
        cudagraph_mode = self.compilation_config.cudagraph_mode
        assert cudagraph_mode is not None
        self.cudagraph_dispatcher.initialize_cudagraph_keys(
            cudagraph_mode, self.uniform_decode_query_len
        )

    def calculate_reorder_batch_threshold(self) -> None:
        """
        Choose the minimum reorder batch threshold from all attention groups.

        Backends should be able to support a lower threshold than the one they
        request. They may just pay a performance penalty, because the backend
        then treats some decodes as prefills.

        Builders that report ``None`` are ignored. If there are no attention
        groups (attention-free model), or no builder reports a threshold,
        ``self.reorder_batch_threshold`` is set to ``None`` and reordering
        stays disabled.
        """
        reported = [
            group.get_metadata_builder().reorder_batch_threshold
            for group in self._attn_group_iterator()
        ]
        thresholds = [value for value in reported if value is not None]
        self.reorder_batch_threshold = min(thresholds) if thresholds else None

    @staticmethod
    def select_common_block_size(
        kv_manager_block_size: int, attn_groups: list[AttentionGroup]
    ) -> int:
        """
        Select a kernel block size that every backend supports and that is a
        factor of kv_manager_block_size.

        If kv_manager_block_size is supported by all backends, it is returned
        directly. Otherwise, the largest ``int``-format supported size
        (collected across all backends) that divides kv_manager_block_size and
        is supported by every backend is returned.

        Args:
            kv_manager_block_size: Block size of KV cache
            attn_groups: List of attention groups

        Returns:
            The selected block size

        Raises:
            ValueError: If no valid block size is found, or if a backend
                reports a supported size that is neither an ``int`` nor a
                ``MultipleOf``.
        """

        def size_matches(block_size: int, supported: object) -> bool:
            # An int entry must match exactly; a MultipleOf entry must divide.
            if isinstance(supported, int):
                return block_size == supported
            if isinstance(supported, MultipleOf):
                return block_size % supported.base == 0
            raise ValueError(f"Unknown supported size: {supported}")

        def block_size_is_supported(
            backends: list[type[AttentionBackend]], block_size: int
        ) -> bool:
            """
            Check if the block size is supported by all backends.
            """
            for backend in backends:
                # Evaluate every advertised size (a list, not a short-circuiting
                # generator) so that an unknown entry always raises.
                matches = [
                    size_matches(block_size, supported)
                    for supported in backend.get_supported_kernel_block_sizes()
                ]
                if not any(matches):
                    return False
            return True

        backends = [group.backend for group in attn_groups]

        # Case 1: if the block_size of kv cache manager is supported by all backends,
        # return it directly
        if block_size_is_supported(backends, kv_manager_block_size):
            return kv_manager_block_size

        # Case 2: otherwise, the block_size must be an `int`-format supported size of
        # at least one backend. Iterate over all `int`-format supported sizes in
        # descending order and return the first one that is supported by all backends.
        # Simple proof:
        # If the supported size b is in MultipleOf(x_i) format for all attention
        # backends i, and b is a factor of kv_manager_block_size, then
        # kv_manager_block_size also satisfies MultipleOf(x_i) for all i. We would
        # then have returned kv_manager_block_size in case 1.
        int_sizes = {
            supported
            for backend in backends
            for supported in backend.get_supported_kernel_block_sizes()
            if isinstance(supported, int)
        }
        divisors = (
            size
            for size in sorted(int_sizes, reverse=True)
            if kv_manager_block_size % size == 0
        )
        for size in divisors:
            if block_size_is_supported(backends, size):
                return size
        raise ValueError(f"No common block size for {kv_manager_block_size}. ")

    def may_reinitialize_input_batch(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> None:
        """
        Rebuild `self.input_batch` when the block layout is not the default.

        The batch is kept as-is only if both the manager block sizes (one per
        KV cache group, excluding encoder-only groups) and
        `kernel_block_sizes` equal `[self.cache_config.block_size]`.
        Otherwise, a new `InputBatch` is created that carries over the
        existing logits processors. This usually happens when there are
        multiple KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
            kernel_block_sizes: The kernel block sizes for each KV cache group.

        Raises:
            AssertionError: If a rebuild is needed while CPU weight offloading
                is enabled.
        """
        # Encoder-only groups are left out of the input batch's block sizes.
        manager_block_sizes = [
            group.kv_cache_spec.block_size
            for group in kv_cache_config.kv_cache_groups
            if not isinstance(group.kv_cache_spec, EncoderOnlyAttentionSpec)
        ]
        default_layout = [self.cache_config.block_size]
        if (
            manager_block_sizes == default_layout
            and kernel_block_sizes == default_layout
        ):
            # The existing batch already matches this layout.
            return

        assert self.cache_config.cpu_offload_gb == 0, (
            "Cannot re-initialize the input batch when CPU weight "
            "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
            "for more details."
        )
        previous_batch = self.input_batch
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=max(self.max_model_len, self.max_encoder_len),
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=manager_block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=previous_batch.logitsprocs,
            logitsprocs_need_output_token_ids=(
                previous_batch.logitsprocs_need_output_token_ids
            ),
            is_pooling_model=self.is_pooling_model,
            num_speculative_tokens=self.num_spec_tokens,
        )

    def _allocate_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig
    ) -> dict[str, torch.Tensor]:
        """
        Allocate one flat, zero-filled int8 buffer per KV cache tensor.

        Every layer listed in a tensor's `shared_by` maps to that same buffer.
        The buffers still need to be viewed/reshaped into the backend's layout
        before the model can use them.

        Args:
            kv_cache_config: The KV cache config.

        Returns:
            dict[str, torch.Tensor]: A map from layer name to its raw buffer.

        Raises:
            AssertionError: If the layers covered by the buffers differ from
                the non-runner-only layers of the KV cache groups.
        """
        raw_buffers: dict[str, torch.Tensor] = {}
        for tensor_spec in kv_cache_config.kv_cache_tensors:
            buffer = torch.zeros(
                tensor_spec.size, dtype=torch.int8, device=self.device
            )
            # All sharing layers alias the same allocation.
            raw_buffers.update(dict.fromkeys(tensor_spec.shared_by, buffer))

        # Runner-only layers are expected to have no buffer of their own.
        expected_layers = {
            layer_name
            for group in kv_cache_config.kv_cache_groups
            for layer_name in group.layer_names
            if layer_name not in self.runner_only_attn_layers
        }
        assert expected_layers == set(raw_buffers), (
            "Some layers are not correctly initialized"
        )
        return raw_buffers

    def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
        """Iterate over every attention group, flattened across KV cache groups."""
        return (
            attn_group
            for groups_in_kv_cache_group in self.attn_groups
            for attn_group in groups_in_kv_cache_group
        )

    def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
        """
        Yield every attention group across all KV cache groups.

        Yields nothing when `self.kv_cache_config` has no KV cache groups.
        Because this is a generator, that check runs on the first iteration
        rather than at call time.
        """
        if self.kv_cache_config.kv_cache_groups:
            yield from itertools.chain.from_iterable(self.attn_groups)

    def _prepare_kernel_block_sizes(self, kv_cache_config: KVCacheConfig) -> list[int]:
        """
        Compute the kernel block size used for each KV cache group.

        - Attention groups: the size chosen by `select_common_block_size`
          across the group's backends. It may be smaller than the manager
          block size, in which case each manager block is split.
        - Mamba groups: the spec's own block size (no splitting).
        - Encoder-only attention groups: skipped, so they get no entry.

        Args:
            kv_cache_config: The KV cache configuration.

        Returns:
            list[int]: Kernel block sizes, one per non-encoder-only group, in
            group order.

        Raises:
            NotImplementedError: For any other KV cache spec type.
        """
        sizes: list[int] = []
        for group_id, cache_group in enumerate(kv_cache_config.kv_cache_groups):
            spec = cache_group.kv_cache_spec
            if isinstance(spec, UniformTypeKVCacheSpecs):
                # Every layer here has the same spec type, so any one of them
                # decides the dispatch below.
                spec = next(iter(spec.kv_cache_specs.values()))

            # Order matters: the encoder-only check must run before the
            # generic AttentionSpec check.
            if isinstance(spec, EncoderOnlyAttentionSpec):
                continue
            if isinstance(spec, AttentionSpec):
                groups_for_kernel = self.attn_groups[group_id]
                manager_block_size = cache_group.kv_cache_spec.block_size
                sizes.append(
                    self.select_common_block_size(
                        manager_block_size, groups_for_kernel
                    )
                )
            elif isinstance(spec, MambaSpec):
                sizes.append(spec.block_size)
            else:
                raise NotImplementedError(
                    f"unknown kv cache spec {cache_group.kv_cache_spec}"
                )
        return sizes

    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor],
        kernel_block_sizes: list[int],
    ) -> dict[str, torch.Tensor]:
        """
        Reshape the KV cache tensors to the desired shape and dtype.

        Attention layers get a view of their raw buffer in the backend's
        KV cache shape and the spec's dtype. The shape is built with the kernel
        block size, so one manager block may map to several kernel blocks.
        Mamba layers get a list of state tensors carved out of the same raw
        buffer, one per ``(shape, dtype)`` pair of the spec. Layers in
        ``self.runner_only_attn_layers`` are skipped. If both attention and
        Mamba layers are present, the attention layouts are then adjusted by
        ``_update_hybrid_attention_mamba_layout``.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer, with
                correct size but uninitialized shape.
            kernel_block_sizes: The kernel block sizes for each KV cache group.

        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache. Mamba layers map to a
            list of tensors rather than a single tensor.

        Raises:
            NotImplementedError: If a layer's spec is neither an
                ``AttentionSpec`` nor a ``MambaSpec``.
        """
        kv_caches: dict[str, torch.Tensor] = {}
        has_attn, has_mamba = False, False
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            attn_backend = group.backend
            if group.kv_cache_group_id == len(kernel_block_sizes):
                # There may be a last group for layers without kv cache
                # (presumably the encoder-only group, for which
                # _prepare_kernel_block_sizes emits no entry).
                continue
            kernel_block_size = kernel_block_sizes[group.kv_cache_group_id]
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    continue
                raw_tensor = kv_cache_raw_tensors[layer_name]
                assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
                num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
                if isinstance(kv_cache_spec, AttentionSpec):
                    has_attn = True
                    # Each manager block is split into this many kernel blocks.
                    num_blocks_per_kv_block = (
                        kv_cache_spec.block_size // kernel_block_size
                    )
                    kernel_num_blocks = num_blocks * num_blocks_per_kv_block

                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        kernel_num_blocks,
                        kernel_block_size,
                        kv_cache_spec.num_kv_heads,
                        kv_cache_spec.head_size,
                        cache_dtype_str=self.cache_config.cache_dtype,
                    )
                    dtype = kv_cache_spec.dtype
                    try:
                        kv_cache_stride_order = attn_backend.get_kv_cache_stride_order()
                        assert len(kv_cache_stride_order) == len(kv_cache_shape)
                    except (AttributeError, NotImplementedError):
                        kv_cache_stride_order = tuple(range(len(kv_cache_shape)))

                    # The allocation respects the backend-defined stride order
                    # to ensure the semantic remains consistent for each
                    # backend. We first obtain the generic kv cache shape and
                    # then permute it according to the stride order which could
                    # result in a non-contiguous tensor.
                    kv_cache_shape = tuple(
                        kv_cache_shape[i] for i in kv_cache_stride_order
                    )
                    # Maintain original KV shape view.
                    inv_order = [
                        kv_cache_stride_order.index(i)
                        for i in range(len(kv_cache_stride_order))
                    ]
                    kv_caches[layer_name] = (
                        raw_tensor.view(dtype)
                        .view(kv_cache_shape)
                        .permute(*inv_order)
                    )
                elif isinstance(kv_cache_spec, MambaSpec):
                    has_mamba = True
                    state_tensors = []
                    # Byte offset of the current state within each page: the
                    # states are packed back to back inside every block.
                    storage_offset_bytes = 0
                    for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes):
                        dtype_size = get_dtype_size(dtype)
                        num_element_per_page = (
                            kv_cache_spec.page_size_bytes // dtype_size
                        )
                        target_shape = (num_blocks, *shape)
                        # Contiguous strides for target_shape. A meta tensor
                        # yields them without allocating a buffer as large as
                        # the whole state on the default device, which
                        # torch.empty(target_shape) would do.
                        stride = torch.empty(target_shape, device="meta").stride()
                        # Consecutive blocks are exactly one page apart.
                        target_stride = (num_element_per_page, *stride[1:])
                        assert storage_offset_bytes % dtype_size == 0
                        tensor = torch.as_strided(
                            raw_tensor.view(dtype),
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset_bytes // dtype_size,
                        )
                        state_tensors.append(tensor)
                        # Advance past this state's per-block footprint.
                        storage_offset_bytes += stride[0] * dtype_size

                    kv_caches[layer_name] = state_tensors
                else:
                    raise NotImplementedError(
                        "Unsupported KV cache spec "
                        f"{type(kv_cache_spec).__name__} for layer {layer_name}"
                    )

        if has_attn and has_mamba:
            self._update_hybrid_attention_mamba_layout(kv_caches)

        return kv_caches

    def _update_hybrid_attention_mamba_layout(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> None:
        """
        Update the layout of attention layers from (2, num_blocks, ...) to
        (num_blocks, 2, ...).

        Only the strides change, in place via ``as_strided_``. The logical
        shape of each tensor stays ``(2, num_blocks, ...)``, but dim 0 (size 2,
        presumably K and V) gets stride ``hidden_size`` and dim 1 gets stride
        ``2 * hidden_size``. In memory, each block therefore holds both halves
        next to each other, presumably to match Mamba's per-block page layout
        (confirm). Attention tensors whose leading dim is not 2 are left alone.
        Non-attention entries are not touched.

        Layers in ``self.runner_only_attn_layers`` are skipped:
        ``_reshape_kv_cache_tensors`` creates no entry for them, so looking them
        up here would raise ``KeyError``.

        Args:
            kv_caches: The KV cache buffer of each layer, modified in place.

        Raises:
            AssertionError: If the first two dims are both 2, which makes the
                current layout ambiguous.
        """
        for group in self._kv_cache_spec_attn_group_iterator():
            kv_cache_spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                if layer_name in self.runner_only_attn_layers:
                    # No tensor of its own (e.g. encoder-only layers).
                    continue
                kv_cache = kv_caches[layer_name]
                if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2:
                    assert kv_cache.shape[1] != 2, (
                        "Fail to determine whether the layout is "
                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for "
                        f"a tensor of shape {kv_cache.shape}"
                    )
                    hidden_size = kv_cache.shape[2:].numel()
                    kv_cache.as_strided_(
                        size=kv_cache.shape,
                        stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]),
                    )

    def initialize_kv_cache_tensors(
        self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
    ) -> dict[str, torch.Tensor]:
        """
        Allocate the KV cache buffers and bind them to the model's layers.

        If `self.use_uniform_kv_cache(...)` accepts the attention groups and
        cache dtype, the caches come from `self.allocate_uniform_kv_caches`.
        That path is intended for kv-connector transfers, and it also returns
        a cross-layer cache and its attention backend, which are stored on
        `self`. Otherwise, raw buffers are allocated and reshaped per layer.
        Afterwards, every layer in `self.shared_kv_cache_layers` is pointed at
        its target layer's cache, and the result is bound with
        `bind_kv_cache`.

        Args:
            kv_cache_config: The KV cache config
            kernel_block_sizes: The kernel block sizes for each KV cache group.

        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        cache_dtype = self.cache_config.cache_dtype
        if not self.use_uniform_kv_cache(self.attn_groups, cache_dtype):
            # General path: allocate flat buffers, then view each one in the
            # backend's shape and dtype.
            raw_buffers = self._allocate_kv_cache_tensors(kv_cache_config)
            kv_caches = self._reshape_kv_cache_tensors(
                kv_cache_config, raw_buffers, kernel_block_sizes
            )
        else:
            # Layout optimized for kv-connector transfers.
            (
                kv_caches,
                self.cross_layers_kv_cache,
                self.cross_layers_attn_backend,
            ) = self.allocate_uniform_kv_caches(
                kv_cache_config,
                self.attn_groups,
                cache_dtype,
                self.device,
                kernel_block_sizes,
            )

        # Cross-layer KV sharing: aliasing layers reuse the target's tensor.
        for alias_layer, target_layer in self.shared_kv_cache_layers.items():
            logger.debug("%s reuses KV cache of %s", alias_layer, target_layer)
            kv_caches[alias_layer] = kv_caches[target_layer]

        is_longcat_flash = self.model_config.hf_config.model_type == "longcat_flash"
        num_attn_module = 2 if is_longcat_flash else 1
        bind_kv_cache(
            kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_caches,
            num_attn_module,
        )
        return kv_caches

    def maybe_add_kv_sharing_layers_to_kv_cache_groups(
        self, kv_cache_config: KVCacheConfig
    ) -> None:
        """
        Add layers that re-use KV cache to the KV cache group of their target
        layer.

        The mapping of the actual KV cache tensors happens later, in
        `initialize_kv_cache_tensors()`. When
        `cache_config.kv_sharing_fast_prefill` is set, this method also records
        the trailing run of KV-sharing attention layers in
        `self.kv_sharing_fast_prefill_eligible_layers`. That run is found by
        scanning attention layers from last to first and stopping at the first
        layer that does not share.

        Args:
            kv_cache_config: KV cache config whose `kv_cache_groups` are
                passed to `add_kv_sharing_layers_to_kv_cache_groups`.
        """
        shared_layers = self.shared_kv_cache_layers
        if not shared_layers:
            # No cross-layer KV sharing; nothing to add.
            return

        add_kv_sharing_layers_to_kv_cache_groups(
            shared_layers,
            kv_cache_config.kv_cache_groups,
            self.runner_only_attn_layers,
        )

        if not self.cache_config.kv_sharing_fast_prefill:
            return

        # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
        # similar KV sharing setups, only the layers that generate KV caches
        # are involved in the prefill phase, enabling prefill to early exit.
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name in reversed(attn_layers):
            if layer_name not in shared_layers:
                break
            self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Works on a deep copy of `kv_cache_config`, stored as
        `self.kv_cache_config`. The copy is extended with encoder-only and
        KV-sharing layers. The method then sets up the attention backends,
        metadata builders, input batch and KV cache tensors. Finally, it
        registers the caches with the KV transfer connector (if any) and checks
        DCP requirements.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer. The caller's object is not modified.
        """
        config = deepcopy(kv_cache_config)
        self.kv_cache_config = config
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(config)
        self.initialize_attn_backend(config)

        # The kernel block size for all KV cache groups. For example, if
        # kv_cache_manager uses block_size 256 for a given group, but the
        # attention backends for that group only support block_size 64, the
        # kernel_block_size is 64 and each 256-token block is split into 4
        # blocks of 64 tokens.
        kernel_block_sizes = self._prepare_kernel_block_sizes(config)

        # Create metadata builders.
        self.initialize_metadata_builders(config, kernel_block_sizes)

        # Must run after initialize_attn_backend.
        self.may_reinitialize_input_batch(config, kernel_block_sizes)
        kv_caches = self.initialize_kv_cache_tensors(config, kernel_block_sizes)

        spec_config = self.speculative_config
        if spec_config and spec_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # All draft model layers must belong to the same KV cache group.
            self.drafter.validate_same_kv_cache_group(config)

        if has_kv_transfer_group():
            connector = get_kv_transfer_group()
            cross_layer_cache = self.cross_layers_kv_cache
            if cross_layer_cache is None:
                connector.register_kv_caches(kv_caches)
            else:
                assert self.cross_layers_attn_backend is not None
                connector.register_cross_layers_kv_cache(
                    cross_layer_cache, self.cross_layers_attn_backend
                )
            connector.set_host_xfer_buffer_ops(copy_kv_blocks)

        if self.dcp_world_size > 1:
            attn_layer_cls = cast(type[Any], AttentionLayerBase)
            attn_layers = get_layers_from_vllm_config(self.vllm_config, attn_layer_cls)
            for attn_layer in attn_layers.values():
                impl = getattr(attn_layer, "impl", None)
                if impl is None:
                    continue
                assert impl.need_to_return_lse_for_decode, (
                    "DCP requires attention impls to return"
                    " the softmax lse for decode, but the impl "
                    f"{impl.__class__.__name__} "
                    "does not return the softmax lse for decode."
                )

    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
        """
        Add encoder-only layers to the KV cache config.

        Every `Attention` layer whose `attn_type` is
        `AttentionType.ENCODER_ONLY` gets an `EncoderOnlyAttentionSpec` and is
        recorded in `self.runner_only_attn_layers`. These layers are then
        appended to `self.kv_cache_config.kv_cache_groups` as one new group.

        Raises:
            AssertionError: If the encoder-only layers produce more than one
                distinct spec.
        """
        block_size = self.vllm_config.cache_config.block_size
        layers_by_spec: dict[AttentionSpec, list[str]] = defaultdict(list)
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for name, module in attn_layers.items():
            if module.attn_type == AttentionType.ENCODER_ONLY:
                spec: AttentionSpec = EncoderOnlyAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=module.num_kv_heads,
                    head_size=module.head_size,
                    dtype=self.kv_cache_dtype,
                )
                # Layers are grouped by their spec (used as the dict key).
                layers_by_spec[spec].append(name)
                self.runner_only_attn_layers.add(name)

        if not layers_by_spec:
            return
        assert len(layers_by_spec) == 1, (
            "Only support one encoder-only attention spec now"
        )
        only_spec, layer_names = layers_by_spec.popitem()
        self.kv_cache_config.kv_cache_groups.append(
            KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=only_spec)
        )

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Collect the KVCacheSpec of every attention-like layer in the model.

        Returns an empty dict when an EC transfer is configured and this
        instance is its producer. Attention layers with a
        `kv_sharing_target_layer_name` get no spec; instead, they are recorded
        in `self.shared_kv_cache_layers` as aliases of their target layer.
        Layers whose `get_kv_cache_spec` returns a falsy value are omitted.

        Returns:
            dict[str, KVCacheSpec]: A mapping from layer name to its KV cache
            spec. Layers that do not need a KV cache are not included.
        """
        if has_ec_transfer() and get_ec_transfer().is_producer:
            return {}

        specs: dict[str, KVCacheSpec] = {}
        attn_layer_cls = cast(type[Any], AttentionLayerBase)
        attn_layers = get_layers_from_vllm_config(self.vllm_config, attn_layer_cls)
        for name, module in attn_layers.items():
            target_layer = (
                module.kv_sharing_target_layer_name
                if isinstance(module, Attention)
                else None
            )
            if target_layer:
                # This layer reuses the target layer's KV cache. Giving it no
                # spec means KV cache management allocates nothing for it,
                # which is where the memory saving of cross-layer sharing
                # comes from (longer contexts or more concurrent requests).
                self.shared_kv_cache_layers[name] = target_layer
                continue
            # Modules that need no KV cache (e.g. encoder-only attention)
            # return a falsy spec and are skipped.
            spec = module.get_kv_cache_spec(self.vllm_config)
            if spec:
                specs[name] = spec
        return specs

    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
        """
        Copy `sampled_token_ids` into a host staging buffer and return it as
        nested Python lists.

        The copy goes into the leading rows of
        `self.sampled_token_ids_pinned_cpu` with `non_blocking=True`. Completion
        is then awaited through `self.transfer_event` rather than by calling
        `.tolist()` on the device tensor directly.

        Args:
            sampled_token_ids: Tensor of sampled token ids; its first dimension
                selects how many rows of the staging buffer are used.

        Returns:
            list[list[int]]: The sampled token ids as nested lists.
        """
        # Short-term mitigation for
        # https://github.com/vllm-project/vllm/issues/22754: `tolist` would
        # trigger a stream-wide CUDA sync that blocks copies issued on other
        # streams. Syncing on a single event avoids that. This is on the
        # critical path of every forward step and caused a perf issue in a
        # disaggregated setup.
        num_rows = sampled_token_ids.shape[0]
        host_buffer = self.sampled_token_ids_pinned_cpu[:num_rows]
        host_buffer.copy_(sampled_token_ids, non_blocking=True)
        event = self.transfer_event
        event.record()
        event.synchronize()
        return host_buffer.tolist()