gpu_model_runner.py 137 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import dataclasses
5
import gc
6
import time
7
from contextlib import contextmanager
8
from typing import TYPE_CHECKING, Any, Optional, Union, cast
9
10
11
12
13

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
14
from tqdm import tqdm
15

16
import vllm.envs as envs
17
from vllm.attention import AttentionType, get_attn_backend
18
from vllm.attention.backends.abstract import AttentionBackend
19
from vllm.attention.layer import Attention
20
from vllm.compilation.counter import compilation_counter
21
from vllm.config import (CompilationLevel, VllmConfig,
22
                         get_layers_from_vllm_config, update_config)
23
from vllm.distributed.eplb.eplb_state import EplbState
24
25
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
26
from vllm.distributed.parallel_state import (
27
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
28
    prepare_communication_buffer_for_model)
29
from vllm.forward_context import DPMetadata, set_forward_context
30
from vllm.logger import init_logger
31
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
32
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
33
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
34
35
36
37
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
                                                   supports_transcription)
from vllm.model_executor.models.interfaces_base import (
    VllmModelForPooling, is_pooling_model, is_text_generation_model)
38
39
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
40
from vllm.multimodal.utils import group_mm_inputs_by_modality
41
from vllm.pooling_params import PoolingParams
42
from vllm.sampling_params import SamplingType
43
from vllm.sequence import IntermediateTensors, PoolerOutput
44
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
45
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
46
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
47
                        is_pin_memory_available, round_up, supports_dynamo)
48
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
49
50
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
51
    make_kv_sharing_fast_prefill_attention_metadata,
52
    make_local_attention_virtual_batches)
53
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
54
55
56
57
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        ChunkedLocalAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec, MambaSpec,
58
                                        SlidingWindowSpec)
59
60
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
61
from vllm.v1.pool.metadata import PoolingMetadata
62
from vllm.v1.sample.metadata import SamplingMetadata
63
from vllm.v1.sample.rejection_sampler import RejectionSampler
64
from vllm.v1.sample.sampler import Sampler
65
from vllm.v1.spec_decode.eagle import EagleProposer
66
from vllm.v1.spec_decode.medusa import MedusaProposer
67
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
68
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
69
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
70
71
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)
72
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
73

74
from ..sample.logits_processor import LogitsProcessorManager
75
76
from .utils import (bind_kv_cache, gather_mm_placeholders,
                    initialize_kv_cache_for_kv_sharing,
77
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
78

79
if TYPE_CHECKING:
80
    import xgrammar as xgr
81
    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
82

83
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
84
    from vllm.v1.core.sched.output import SchedulerOutput
85
86
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
87
88
89
    xgr_torch_compile = LazyLoader(
        "xgr_torch_compile", globals(),
        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
90
91
92
93

# Module-level logger, named after this module via vLLM's `init_logger`.
logger = init_logger(__name__)


94
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
95
96
97

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up per-worker configuration, helpers and persistent buffers.

        Args:
            vllm_config: The engine-wide configuration. Its sub-configs are
                cached as attributes for convenient access.
            device: The device on which the persistent GPU buffers below are
                allocated.

        Some attributes (e.g. ``self.model`` and ``self.kv_cache_config``)
        are not set here; they are filled in later by ``load_model`` and
        ``initialize_kv_cache`` respectively.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        # cpu_offload_gb is expressed in GiB; the setter expects bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * GiB_bytes))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            # "auto" means the KV cache uses the model's own dtype.
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.is_pooling_model = model_config.pooler_config is not None
        self.is_encoder_only_model = False
        self.is_multimodal_raw_input_supported = (
            model_config.is_multimodal_raw_input_supported)
        self.max_model_len = model_config.max_model_len
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Sampler
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        # State of the expert parallelism load balancer.
        # Will be lazily initialized when the model is loaded.
        self.eplb_state: Optional[EplbState] = None

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
        self.attn_backends: list[type[AttentionBackend]] = []
        # self.kv_cache_config: KVCacheConfig

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        self.use_aux_hidden_state_outputs = False
        # Set up speculative decoding.
        # NOTE(Jiayi): currently we put the entire draft model on
        # the last PP rank. This is not ideal if there are many
        # layers in the draft model.
        if self.speculative_config and get_pp_group().is_last_rank:
            if self.speculative_config.method == "ngram":
                self.drafter = NgramProposer(self.vllm_config)
            elif self.speculative_config.use_eagle():
                self.drafter = EagleProposer(self.vllm_config, self.device,
                                             self)  # type: ignore
                if self.speculative_config.method == "eagle3":
                    self.use_aux_hidden_state_outputs = True
            elif self.speculative_config.method == "medusa":
                self.drafter = MedusaProposer(
                    vllm_config=self.vllm_config,
                    device=self.device)  # type: ignore
            else:
                raise ValueError("Unknown speculative decoding method: "
                                 f"{self.speculative_config.method}")
            self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Input Batch
        # NOTE(Chen): Ideally, we should initialize the input batch inside
        # `initialize_kv_cache` based on the kv cache config. However, as in
        # https://github.com/vllm-project/vllm/pull/18298, due to some unknown
        # reasons, we have to initialize the input batch before `load_model`,
        # quantization + weight offloading will fail otherwise. As a temporary
        # solution, we initialize the input batch here, and re-initialize it
        # in `initialize_kv_cache` if the block_sizes here is different from
        # the block_sizes in the kv cache config.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=[self.cache_config.block_size],
            is_spec_decode=bool(self.vllm_config.speculative_config),
        )

        self.use_cuda_graph = (
            self.compilation_config.level == CompilationLevel.PIECEWISE
            and self.compilation_config.use_cudagraph
            and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(self.compilation_config.cudagraph_capture_sizes))

        self.full_cuda_graph = self.compilation_config.full_cuda_graph

        # Cache the device properties.
        self._init_device_properties()

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
        self.slot_mapping = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)
            self.mrope_positions_np = self.mrope_positions_cpu.numpy()

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int64)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

        # Layer pairings for cross-layer KV sharing.
        # If an Attention layer `layer_name` is in the keys of this dict, it
        # means this layer will perform attention using the keys and values
        # from the KV cache of `shared_kv_cache_layers[layer_name]`.
        self.shared_kv_cache_layers: dict[str, str] = {}
        self.kv_sharing_fast_prefill_eligible_layers: set[str] = set()

        # Only allocated when kv_sharing_fast_prefill is enabled.
        self.kv_sharing_fast_prefill_logits_indices: Optional[
            torch.Tensor] = None
        if self.cache_config.kv_sharing_fast_prefill:
            self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
                self.max_num_tokens, dtype=torch.int32, device=self.device)

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
333
334
        """
        Update the order of requests in the batch based on the attention
335
        backend's needs. For example, some attention backends (namely MLA) may
336
337
338
339
340
341
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.
        """
342
343
344
345
346
347
348
349
        # Attention free models have zero kv_cache_goups, however models
        # like Mamba are also attention free but use the kv_cache for
        # keeping its internal state. This is why we check the number
        # of kv_cache groups instead of solely checking
        # for self.model_config.is_attention_free.
        if len(self.kv_cache_config.kv_cache_groups) == 0:
            return

350
351
        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                     scheduler_output)
352
353
354
355
356

        # For models with multiple KV cache groups, the groups should agree on
        # the same order of requests. We ensure this by only allowing the first
        # group to reorder the batch and asserting that all other groups do not
        # reorder the batch.
357
358
359
        # TODO(tdoublep): make this more flexible so that any group can
        # re-order the batch (not only the first).
        # TODO(tdoublep): verify this during engine init instead of at runtime
360
        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
361
            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
362
                self.input_batch, scheduler_output)
363
            assert not batch_reordered
364

365
366
367
368
369
370
371
372
373
374
375
    # Note: used for model runner override.
    def _init_device_properties(self) -> None:
        """Cache the CUDA properties of ``self.device`` on the runner.

        Sets ``self.device_properties`` to the result of
        ``torch.cuda.get_device_properties`` and ``self.num_sms`` to its
        ``multi_processor_count``.
        """
        props = torch.cuda.get_device_properties(self.device)
        self.device_properties = props
        self.num_sms = props.multi_processor_count

    # Note: used for model runner override.
    def _sync_device(self) -> None:
        """Wait for outstanding CUDA work via ``torch.cuda.synchronize``."""
        return torch.cuda.synchronize()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Order of operations:
          1. drop finished requests (cached state, encoder cache, batch);
          2. free encoder outputs the scheduler released;
          3. evict unscheduled requests from the persistent batch;
          4. register newly scheduled requests;
          5. update running/resumed requests;
          6. (re-)add requests missing from the persistent batch, then
             condense, possibly reorder, and refresh the batch metadata.

        Args:
            scheduler_output: The scheduler output for the current step.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        # The set difference is materialized eagerly, so removing entries
        # from the batch while iterating below is safe.
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            self.input_batch.remove_request(req_id)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            pooling_params = new_req_data.pooling_params

            if sampling_params and \
                sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            if pooling_params:
                # Bind `task` outside the assert: asserts are stripped under
                # `python -O`, and a walrus inside the assert would then leave
                # `task` undefined (NameError on the next line).
                task = pooling_params.task
                assert task is not None, "You did not set `task` in the API"

                model = cast(VllmModelForPooling, self.model)
                to_update = model.pooler.get_pooling_updates(task)
                to_update.apply(pooling_params)

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                pooling_params=pooling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                audio_feature_lengths = []
                use_audio_in_video = False
                for mm_input in self.requests[req_id].mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])
                    if mm_input.get("audio_feature_lengths") is not None:
                        audio_feature_lengths.extend(
                            mm_input["audio_feature_lengths"])
                    if mm_input.get("use_audio_in_video") is True:
                        use_audio_in_video = True

                hf_config = self.model_config.hf_config

                self.requests[req_id].mrope_positions, \
                    self.requests[req_id].mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        self.requests[req_id].prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                        audio_feature_lengths=audio_feature_lengths,
                        use_audio_in_video=use_audio_in_video,
                    )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        is_last_rank = get_pp_group().is_last_rank
        req_data = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(req_data.req_ids):
            req_state = self.requests[req_id]
            num_computed_tokens = req_data.num_computed_tokens[i]
            new_block_ids = req_data.new_block_ids[i]
            resumed_from_preemption = req_data.resumed_from_preemption[i]

            # Update the cached states.
            req_state.num_computed_tokens = num_computed_tokens

            if not is_last_rank:
                # When using PP, the scheduler sends the sampled tokens back,
                # because there's no direct communication between the first-
                # stage worker and the last-stage worker.
                new_token_ids = req_data.new_token_ids[i]
                # Add the sampled token(s) from the previous step (if any).
                # This doesn't include "unverified" tokens like spec tokens.
                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                                  req_state.num_tokens)
                if num_new_tokens == 1:
                    # Avoid slicing list in most common case.
                    req_state.output_token_ids.append(new_token_ids[-1])
                elif num_new_tokens > 0:
                    req_state.output_token_ids.extend(
                        new_token_ids[-num_new_tokens:])

            # Update the block IDs.
            if not resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                for block_ids, new_ids in zip(req_state.block_ids,
                                              new_block_ids):
                    block_ids.extend(new_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(new_block_ids, req_index)

            # For the last rank, we don't need to update the token_ids_cpu
            # because the sampled tokens are already cached.
            if not is_last_rank:
                # Add new_token_ids to token_ids_cpu.
                start_token_index = num_computed_tokens
                end_token_index = num_computed_tokens + len(new_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index,
                    start_token_index:end_token_index] = new_token_ids
                self.input_batch.num_tokens_no_spec[
                    req_index] = end_token_index
                self.input_batch.num_tokens[req_index] = end_token_index

            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = (
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
            if spec_token_ids:
                num_spec_tokens = len(spec_token_ids)
                start_index = self.input_batch.num_tokens_no_spec[req_index]
                end_token_index = start_index + num_spec_tokens
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
                # NOTE(woosuk): `num_tokens` here may include spec tokens.
                self.input_batch.num_tokens[req_index] += num_spec_tokens

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            self.input_batch.add_request(req_state)

        # Condense the batched states if there are gaps left by removed requests
        self.input_batch.condense()
        # Allow the attention backend to reorder the batch, if it needs to.
        self._may_reorder_batch(scheduler_output)
        # Refresh batch metadata with any pending updates.
        self.input_batch.refresh_metadata()

    def _init_model_kwargs_for_multimodal_model(
        self,
        scheduler_output: Optional["SchedulerOutput"] = None,
        num_reqs: int = -1,
    ) -> dict[str, Any]:
        """Build the extra model kwargs that carry raw multimodal inputs.

        Only models with ``is_multimodal_raw_input_supported`` set get any
        entries; every other model receives an empty dict.

        Args:
            scheduler_output: Scheduler output for this step. When given, the
                ``mm_inputs`` of all newly scheduled requests are batched
                together. When ``None`` (a dummy run), decoder dummy data is
                generated instead.
            num_reqs: Number of dummy entries to generate when
                ``scheduler_output`` is ``None``.

        Returns:
            Keyword arguments to forward to the model.
        """
        model_kwargs: dict[str, Any] = {}
        if not self.is_multimodal_raw_input_supported:
            return model_kwargs

        # This model requires the raw multimodal data in input.
        if scheduler_output:
            raw_mm_inputs = []
            for new_req in scheduler_output.scheduled_new_reqs:
                req_mm_inputs = new_req.mm_inputs
                if not isinstance(req_mm_inputs, list):
                    req_mm_inputs = list(req_mm_inputs)
                raw_mm_inputs.extend(req_mm_inputs)
            batched_kwargs = MultiModalKwargs.batch(raw_mm_inputs)
        else:
            # The only case where SchedulerOutput is None is a dummy run, so
            # synthesize one dummy multimodal item per request.
            batched_kwargs = MultiModalKwargs.batch([
                self.mm_registry.get_decoder_dummy_data(
                    model_config=self.model_config,
                    seq_len=1).multi_modal_data for _ in range(num_reqs)
            ])

        model_kwargs.update(batched_kwargs)
        return model_kwargs

619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
    def _get_cumsum_and_arange(
        self,
        num_tokens: np.ndarray,
        cumsum_dtype: Optional[np.dtype] = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get the cumulative sum and batched arange of the given array.

        E.g., ``[2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])``.
        The arange part is equivalent to, but faster than,
        ``np.concatenate([np.arange(n) for n in num_tokens])``.

        Args:
            num_tokens: 1-D array of per-entry counts.
            cumsum_dtype: Optional dtype forwarded to ``np.cumsum``.

        Returns:
            ``(cu_num_tokens, arange)``. Both arrays are empty when
            ``num_tokens`` is empty.
        """
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
        # Indexing [-1] on an empty cumsum would raise IndexError, so treat
        # an empty input as zero total tokens.
        total_num_tokens = cu_num_tokens[-1] if cu_num_tokens.size else 0
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Assumes self.arange_np is a precomputed 0..K arange with
        # K >= total_num_tokens -- TODO confirm at allocation site.
        arange = self.arange_np[:total_num_tokens] - cumsums_offsets

        return cu_num_tokens, arange

639
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[dict[str,
                    Any], bool, torch.Tensor, Optional[SpecDecodeMetadata],
               np.ndarray, Optional[CommonAttentionMetadata]]:
        """Prepare model inputs and attention metadata for one step.

        Fills the persistent CPU/GPU input buffers (token ids, positions,
        query_start_loc, seq_lens, slot mappings) for the scheduled tokens
        and builds per-layer attention metadata.

        :return: tuple[
            attn_metadata: layer-to-attention_metadata mapping,
            attention_cuda_graphs: whether attention can run in cudagraph,
            logits_indices: indices of the tokens whose logits are computed,
            spec_decode_metadata: speculative decoding metadata, or None
                when no draft tokens are scheduled,
            num_scheduled_tokens: np.int32 array with the number of
                scheduled tokens per request, in ``input_batch.req_ids``
                order,
            spec_decode_common_attn_metadata: the first KV cache group's
                CommonAttentionMetadata (before any chunked-local-attention
                rewrite) when a speculative config is set, otherwise None,
        ]
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit_block_table(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        cu_num_tokens, arange = self._get_cumsum_and_arange(
            num_scheduled_tokens)

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        self.input_batch.block_table.compute_slot_mapping(
            req_indices, positions_np)
        self.input_batch.block_table.commit_slot_mapping(
            total_num_scheduled_tokens)

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Fill unused with 0 for full cuda graph mode.
        self.seq_lens[num_reqs:].fill_(0)
        # Note: pad query_start_loc to be non-decreasing, as kernels
        # like FlashAttention requires that
        self.query_start_loc[num_reqs + 1:].fill_(
            self.query_start_loc_cpu[num_reqs].item())

        query_start_loc = self.query_start_loc[:num_reqs + 1]

        spec_decode_common_attn_metadata = None

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        logits_indices_padded = None
        if self.cache_config.kv_sharing_fast_prefill:
            assert self.kv_sharing_fast_prefill_logits_indices is not None
            num_logits = logits_indices.shape[0]
            assert num_logits > 0
            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
                logits_indices)
            # There might have leftover indices in logits_indices[num_logits:]
            # from previous iterations, whose values may be greater than the
            # batch size in the current iteration. To ensure indices are always
            # valid, we fill the padded indices with the last index.
            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
                logits_indices[-1].item())
            if (self.use_cuda_graph
                    and num_logits <= self.cudagraph_batch_sizes[-1]):
                # Use piecewise CUDA graphs.
                # Add padding to the batch size.
                num_logits_padded = self.vllm_config.pad_for_cudagraph(
                    num_logits)
            else:
                num_logits_padded = num_logits
            logits_indices_padded = (
                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
            )

        attn_metadata: dict[str, Any] = {}

        # Prepare encoder attention metadata separately
        # (encoder layers are not in KV cache groups)
        if self.is_encoder_only_model:
            common_attn_metadata, encoder_attn_metadata = \
                self._build_encoder_only_attn_metadata(
                scheduler_output)

            # Add encoder attention metadata for all encoder layers
            attention_layers = get_layers_from_vllm_config(
                self.vllm_config, Attention)
            for layer_name, attn_module in attention_layers.items():
                if attn_module.attn_type == AttentionType.ENCODER_ONLY:
                    attn_metadata[layer_name] = encoder_attn_metadata

        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            blk_table = self.input_batch.block_table[kv_cache_group_id]
            blk_table_tensor = blk_table.get_device_tensor()[:num_reqs]
            slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens]

            # Fill unused with -1. Needed for reshape_and_cache in full cuda
            # graph mode.
            blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=self.query_start_loc[:num_reqs + 1],
                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
                seq_lens=self.seq_lens[:num_reqs],
                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
                num_computed_tokens_cpu=self.input_batch.
                num_computed_tokens_cpu_tensor[:num_reqs],
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                block_table_tensor=blk_table_tensor,
                slot_mapping=slot_mapping,
                causal=True,
            )

            # Speculative decoding reuses the first group's metadata, captured
            # before the chunked-local-attention rewrite below.
            if self.speculative_config and \
                spec_decode_common_attn_metadata is None:
                spec_decode_common_attn_metadata = common_attn_metadata

            if isinstance(kv_cache_group_spec.kv_cache_spec,
                          ChunkedLocalAttentionSpec):
                common_attn_metadata = make_local_attention_virtual_batches(
                    kv_cache_group_spec.kv_cache_spec.attention_chunk_size,
                    common_attn_metadata, self.cache_config.block_size)

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            builder = self.attn_metadata_builders[kv_cache_group_id]
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    builder,
                )

            attn_metadata_i = (builder.build(
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata,
            ))

            fast_prefill_metadata = attn_metadata_i
            if (self.cache_config.kv_sharing_fast_prefill
                    and self.kv_sharing_fast_prefill_eligible_layers):
                # Dynamically create a dataclass type that inherits
                # from attention metadata type but includes additional
                # fields logits_indices_padded and num_logits_indices
                # which are required for prefill truncation
                fast_prefill_metadata_type = (
                    make_kv_sharing_fast_prefill_attention_metadata(
                        metadata_cls=type(attn_metadata_i), ))
                fast_prefill_metadata = fast_prefill_metadata_type(
                    **dataclasses.asdict(attn_metadata_i),
                    logits_indices_padded=logits_indices_padded,
                    num_logits_indices=logits_indices.size(0),
                )

            for layer_name in kv_cache_group_spec.layer_names:
                if (self.cache_config.kv_sharing_fast_prefill and layer_name
                        in self.kv_sharing_fast_prefill_eligible_layers):
                    attn_metadata[layer_name] = fast_prefill_metadata
                    continue

                attn_metadata[layer_name] = attn_metadata_i

            # Hack for now to fix chunked local attention + no hybrid kv cache
            # manager we can remove this once
            # https://github.com/vllm-project/vllm/pull/21588
            # is merged (i.e. properly handle different attention backends for
            # the same kv_cache_spec)
            if self.attention_chunk_size is not None \
                    and self.scheduler_config.disable_hybrid_kv_cache_manager:
                if not hasattr(self, "local_attention_layers"):
                    self.local_attention_layers = []
                    attn_layers = get_layers_from_vllm_config(
                        self.vllm_config, Attention)
                    for layer_name, attn_module in attn_layers.items():
                        if attn_module.use_irope:
                            self.local_attention_layers.append(layer_name)

                local_attn_metadata_i = (builder.build(
                    common_prefix_len=0,
                    common_attn_metadata=make_local_attention_virtual_batches(
                        self.attention_chunk_size, common_attn_metadata,
                        self.cache_config.block_size),
                ))

                for layer_name in self.local_attention_layers:
                    attn_metadata[layer_name] = local_attn_metadata_i

        # NOTE(review): `common_attn_metadata` here is whatever the last KV
        # cache group iteration (or the encoder-only branch) bound. If there
        # are no KV cache groups and the model is not encoder-only, this name
        # is unbound -- confirm callers guarantee at least one group.
        attention_cuda_graphs = all(
            b.can_run_in_cudagraph(common_attn_metadata)
            for b in self.attn_metadata_builders)

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return (attn_metadata, attention_cuda_graphs, logits_indices,
                spec_decode_metadata, num_scheduled_tokens,
                spec_decode_common_attn_metadata)
926

927
928
929
930
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.
            kv_cache_spec: KV cache spec of the group. Supplies the block
                size, the number of KV heads and the sliding-window /
                chunked-local-attention settings.
            attn_metadata_builder: Builder whose ``use_cascade_attention``
                decides whether cascade attention should be used.

        Returns:
            int: Length of common prefix in tokens, as a plain Python int.
        """
        common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        # Cast the numpy minimum to int so the result type does not depend on
        # which operand of min() wins (the signature promises an int).
        common_prefix_len = min(
            common_prefix_len,
            int(self.input_batch.num_computed_tokens_cpu[:num_reqs].min()))
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                             kv_cache_spec.block_size)
        use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                              (isinstance(kv_cache_spec, FullAttentionSpec)
                               and kv_cache_spec.sliding_window is not None))
        use_local_attention = (
            isinstance(kv_cache_spec, ChunkedLocalAttentionSpec)
            or (isinstance(kv_cache_spec, FullAttentionSpec)
                and kv_cache_spec.attention_chunk_size is not None))
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=use_sliding_window,
            use_local_attention=use_local_attention,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

1022
1023
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        """Fill ``mrope_positions_cpu`` for every scheduled token of the batch.

        Requests are visited in ``input_batch.req_ids`` order and their
        positions are written back to back. Scheduled tokens that fall
        inside the prompt copy the precomputed ``req.mrope_positions``.
        Tokens past the prompt (completion tokens) are generated by
        ``MRotaryEmbedding.get_next_input_positions_tensor``.
        """
        write_ptr = 0
        for req_index, req_id in enumerate(self.input_batch.req_ids):
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            computed = self.input_batch.num_computed_tokens_cpu[req_index]
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            prompt_len = len(req.prompt_token_ids)

            # Split this step's tokens into a prompt chunk and a
            # completion chunk.
            if computed + scheduled <= prompt_len:
                prompt_chunk = scheduled
                completion_chunk = 0
            else:
                prompt_chunk = max(0, prompt_len - computed)
                completion_chunk = max(0, scheduled - prompt_chunk)

            assert scheduled == prompt_chunk + completion_chunk

            if prompt_chunk > 0:
                # Prompt positions are precomputed; copy the relevant slice.
                dst = slice(write_ptr, write_ptr + prompt_chunk)
                src = slice(computed, computed + prompt_chunk)
                self.mrope_positions_cpu[:, dst] = req.mrope_positions[:, src]
                write_ptr += prompt_chunk

            if completion_chunk > 0:
                # Completion positions are computed on the fly.
                MRotaryEmbedding.get_next_input_positions_tensor(
                    out=self.mrope_positions_np,
                    out_offset=write_ptr,
                    mrope_position_delta=req.mrope_position_delta,
                    context_len=computed + prompt_chunk,
                    num_new_tokens=completion_chunk,
                )
                write_ptr += completion_chunk

1072
1073
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the speculative-decoding metadata for this step.

        Each request samples ``num_draft_tokens + 1`` tokens, taken from the
        tail of its scheduled tokens. Worked example::

            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
            -> cu_num_draft_tokens:   [  3,   3,   5,   5,   6]
            -> logits_indices:        [  0,   1,   2,   3, 103, 104, 105, 106,
                                       206, 207, 208]
            -> target_logits_indices: [  0,   1,   2,   5,   6,   9]
            -> bonus_logits_indices:  [  3,   4,   7,   8,  10]

        ``target_logits_indices`` and ``bonus_logits_indices`` index into
        ``logits_indices``. All index tensors are moved to ``self.device``.
        """
        # Sampled tokens per request: [4, 1, 3, 1, 2]
        num_sampled = num_draft_tokens + 1

        # cu_num_sampled:  [4, 5, 8, 9, 11]
        # sampled_offsets: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        cu_num_sampled, sampled_offsets = self._get_cumsum_and_arange(
            num_sampled, cumsum_dtype=np.int32)
        # Start of each request's sampled window, repeated per sampled token:
        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_idx_np = np.repeat(cu_num_scheduled_tokens - num_sampled,
                                  num_sampled)
        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_idx_np += sampled_offsets

        # The bonus token is the last sampled token of each request.
        bonus_idx_np = cu_num_sampled - 1

        # cu_num_draft:  [3, 3, 5, 5, 6]
        # draft_offsets: [0, 1, 2, 0, 1, 0]
        cu_num_draft, draft_offsets = self._get_cumsum_and_arange(
            num_draft_tokens, cumsum_dtype=np.int32)
        # [0, 0, 0, 5, 5, 9] -> [0, 1, 2, 5, 6, 9]
        target_idx_np = np.repeat(cu_num_sampled - num_sampled,
                                  num_draft_tokens)
        target_idx_np += draft_offsets

        # TODO: Optimize the CPU -> GPU copy.
        def to_device(arr: np.ndarray) -> torch.Tensor:
            return torch.from_numpy(arr).to(self.device, non_blocking=True)

        cu_num_draft_tokens = to_device(cu_num_draft)
        logits_indices = to_device(logits_idx_np)
        target_logits_indices = to_device(target_idx_np)
        bonus_logits_indices = to_device(bonus_idx_np)

        # Each draft token sits one position after its target position:
        # draft_token_indices: [  1,   2,   3, 105, 106, 208]
        sampled_input_ids = self.input_ids[logits_indices]
        draft_token_ids = sampled_input_ids[target_logits_indices + 1]

        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

1140
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Runs ``self.model.get_multimodal_embeddings`` group by group and
        stores each output, passed through ``scatter_mm_placeholders``, in
        ``self.encoder_cache[req_id][input_id]``. Returns immediately when
        no encoder inputs are scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Flatten the scheduled inputs into two parallel lists, so each
        # encoder output can later be matched with (request, input, position).
        mm_inputs: list[MultiModalKwargs] = []
        req_ids_pos: list[tuple[str, int, PlaceholderRange]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        encoder_outputs = []
        for grouped_mm_inputs in group_mm_inputs_by_modality(mm_inputs):
            device_kwargs = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(grouped_mm_inputs,
                                       pin_memory=self.pin_memory),
                device=self.device,
            )

            # `curr_group_outputs` is either:
            # 1. a tensor of shape (num_items, feature_size, hidden_size),
            #    when feature_size is fixed across all multimodal items, or
            # 2. a list/tuple (length: num_items) of tensors, each of shape
            #    (feature_size, hidden_size), when the feature size depends
            #    on the input item.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **device_kwargs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs.
        for (req_id, input_id, pos_info), output in zip(req_ids_pos,
                                                         encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        """Collect the cached encoder embeddings needed by this step.

        For each request (in ``input_batch.req_ids`` order) this looks at
        the token window ``[num_computed_tokens, num_computed_tokens +
        num_scheduled_tokens)``. For each multimodal placeholder that
        overlaps the window, the overlapping slice of the cached encoder
        output goes through ``gather_mm_placeholders`` and is appended to
        the result.
        """
        mm_embeds: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            req_state = self.requests[req_id]
            window_start = req_state.num_computed_tokens
            window_end = window_start + scheduled

            for i, pos_info in enumerate(req_state.mm_positions):
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length

                if start_pos >= window_end:
                    # Not needed in this step; stop scanning this request.
                    break
                if start_pos + num_encoder_tokens <= window_start:
                    # Already processed and stored in the decoder's KV cache.
                    continue

                start_idx = max(window_start - start_pos, 0)
                end_idx = min(window_end - start_pos, num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]

                is_embed = pos_info.is_embed
                if is_embed is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds.append(
                    gather_mm_placeholders(
                        encoder_output[start_idx:end_idx],
                        is_embed=is_embed,
                    ))
        return mm_embeds
1250

1251
1252
1253
    def get_model(self) -> nn.Module:
        """Return the module stored on this runner as ``self.model``."""
        model = self.model
        return model

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        """List the generation tasks the loaded model can serve.

        Returns:
            ``["transcription"]`` if the model supports transcription and
            sets ``supports_transcription_only``. Otherwise, ``"generate"``
            if the model is a text-generation model, followed by
            ``"transcription"`` if it supports transcription.
        """
        model = self.get_model()
        tasks: list[GenerationTask] = []

        if is_text_generation_model(model):
            tasks.append("generate")

        if not supports_transcription(model):
            return tasks
        if model.supports_transcription_only:
            return ["transcription"]

        tasks.append("transcription")
        return tasks

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """List the pooling tasks reported by the model's pooler.

        Returns an empty list when the loaded model is not a pooling model.
        """
        model = self.get_model()
        if is_pooling_model(model):
            return list(model.pooler.get_supported_tasks())
        return []

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Return the tasks available for the configured runner type.

        ``runner_type == "generate"`` yields the generation tasks,
        ``runner_type == "pooling"`` yields the pooling tasks, and any other
        runner type yields an empty tuple.
        """
        runner_type = self.model_config.runner_type
        tasks: list[SupportedTask] = []
        if runner_type == "generate":
            tasks += self.get_supported_generation_tasks()
        elif runner_type == "pooling":
            tasks += self.get_supported_pooling_tasks()
        return tuple(tasks)

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output grammar bitmask to ``logits`` in place.

        The scheduler sends a compacted numpy bitmask that has rows only for
        structured-output requests. Each such request has one row for the
        sampled position plus one per scheduled speculative token, and the
        requests are ordered by the values of
        ``scheduler_output.structured_output_request_ids``. That order need
        not match this runner's batch order. This method scatters the rows
        into a bitmask aligned with the rows of ``logits`` and applies it
        only to the rows that were filled.

        Args:
            scheduler_output: Scheduler output for the current step.
            logits: Logits to mask in place. Rows include speculative
                positions, so a request's first row is shifted by the
                speculative tokens of all earlier requests in the batch.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        spec_tokens = scheduler_output.scheduled_spec_decode_tokens

        # Map each structured-output request to the logits row of its first
        # position, walking the batch in batch-index order and accumulating
        # the speculative-token offset of earlier requests.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        seq = sorted(self.input_batch.req_id_to_index.items(),
                     key=lambda x: x[1])
        for req_id, batch_index in seq:
            logit_index = batch_index + cumulative_offset
            cumulative_offset += len(spec_tokens.get(req_id, []))
            if req_id in scheduler_output.structured_output_request_ids:
                struct_out_req_batch_indices[req_id] = logit_index

        out_indices: list[int] = []

        # Scatter the compacted rows into a bitmask aligned with `logits`.
        # A request's rows are contiguous on both sides, so each request is
        # copied with a single slice assignment instead of row by row.
        sorted_bitmask = np.zeros_like(grammar_bitmask,
                                       shape=(logits.shape[0],
                                              grammar_bitmask.shape[1]))
        cumulative_index = 0
        seq = sorted(scheduler_output.structured_output_request_ids.items(),
                     key=lambda x: x[1])
        for req_id, _ in seq:
            logit_index = struct_out_req_batch_indices[req_id]
            num_rows = 1 + len(spec_tokens.get(req_id, []))
            sorted_bitmask[logit_index:logit_index + num_rows] = \
                grammar_bitmask[cumulative_index:cumulative_index + num_rows]
            out_indices.extend(range(logit_index, logit_index + num_rows))
            cumulative_index += num_rows

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so the bitmask arrives as numpy and is converted only here.
        grammar_bitmask = torch.from_numpy(sorted_bitmask)

        # Force use of the torch.compile implementation from xgrammar to work
        # around issues with the Triton kernel in concurrent structured output
        # scenarios. See PR #19565 and issues #19493, #18376 for details.
        xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=out_indices,
        )

    def sync_and_slice_intermediate_tensors(
            self, num_tokens: int, intermediate_tensors: IntermediateTensors,
            sync_self: bool) -> IntermediateTensors:
        """Return views of the persistent intermediate-tensor buffers.

        Args:
            num_tokens: Number of (padded) tokens in this step.
            intermediate_tensors: Incoming tensors. They are copied into
                ``self.intermediate_tensors`` when ``sync_self`` is true.
            sync_self: Whether to copy ``intermediate_tensors`` into the
                persistent buffers before slicing.

        Returns:
            An ``IntermediateTensors`` holding the first ``num_tokens`` rows
            of each buffer. For ``"residual"``, it holds the first
            ``num_tokens // tp`` rows when the residual is scattered across
            tensor-parallel ranks.
        """
        assert self.intermediate_tensors is not None

        tp = self.vllm_config.parallel_config.tensor_parallel_size
        enabled_sp = \
            self.compilation_config.pass_config.enable_sequence_parallelism
        if enabled_sp:
            # With sequence parallelism, num_tokens was already padded to a
            # multiple of the tensor-parallel size earlier in the step.
            assert num_tokens % tp == 0
        is_residual_scattered = (tp > 1 and enabled_sp
                                 and num_tokens % tp == 0)

        def rows_for(key: str) -> int:
            # With sequence parallelism the "residual" tensor is sharded
            # across tensor-parallel ranks, so each rank keeps only its slice.
            if key == "residual" and is_residual_scattered:
                return num_tokens // tp
            return num_tokens

        if sync_self:
            assert intermediate_tensors is not None
            for key, src in intermediate_tensors.items():
                n_rows = rows_for(key)
                self.intermediate_tensors[key][:n_rows].copy_(
                    src[:n_rows], non_blocking=True)

        return IntermediateTensors({
            key: buf[:rows_for(key)]
            for key, buf in self.intermediate_tensors.items()
        })

    def eplb_step(self,
                  is_dummy: bool = False,
                  is_profile: bool = False) -> None:
        """Advance the EPLB (Expert Parallelism Load Balancing) state.

        This is a no-op unless ``parallel_config.enable_eplb`` is set.

        Args:
            is_dummy: Forwarded positionally to ``self.eplb_state.step``.
            is_profile: Forwarded positionally to ``self.eplb_state.step``.
        """
        parallel_config = self.parallel_config
        if parallel_config.enable_eplb:
            assert self.eplb_state is not None
            assert is_mixture_of_experts(self.model)
            self.eplb_state.step(
                self.model,
                is_dummy,
                is_profile,
                log_stats=parallel_config.eplb_log_balancedness,
            )

    def get_dp_padding(self,
                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
        """Compute padding so that every data-parallel rank runs the same size.

        Args:
            num_tokens: Number of tokens this rank is about to run.

        Returns:
            ``(num_pad, num_tokens_after_padding)``. ``num_pad`` is how many
            tokens to add so this rank matches the largest token count
            across DP ranks. ``num_tokens_after_padding`` is a CPU int32
            tensor of length ``data_parallel_size`` filled with that
            maximum. Returns ``(0, None)`` when the DP size is 1 or
            ``enforce_eager`` is set.
        """
        parallel_config = self.vllm_config.parallel_config
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank

        # For DP, skip padding when enforce_eager is set. In a P/D setup this
        # lets the prefiller run eagerly while the decoder still uses CUDA
        # graphs, which this padding is what enables.
        #
        # TODO(tms): There are many cases where padding is enabled for
        # prefills, causing unnecessary and excessive padding of activations.
        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
            return 0, None

        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        padded_len = torch.max(num_tokens_across_dp).item()
        num_tokens_after_padding = torch.full((dp_size, ),
                                              padded_len,
                                              device="cpu",
                                              dtype=torch.int32)
        return padded_len - num_tokens, num_tokens_after_padding

    def _pool(
        self,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        num_scheduled_tokens_np: np.ndarray,
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]],
    ) -> ModelRunnerOutput:
        """Run the model's pooler over this step's hidden states.

        Args:
            hidden_states: Model output. Its first ``num_scheduled_tokens``
                rows are split per request.
            num_scheduled_tokens: Total number of tokens scheduled this step.
            num_scheduled_tokens_np: Per-request scheduled token counts, in
                batch order. Used as the split sizes.
            finished_sending: Passed through to the returned output.
            finished_recving: Passed through to the returned output.

        Returns:
            A ``ModelRunnerOutput`` whose ``pooler_output`` has one entry per
            request. The entry is the pooled result copied to CPU when the
            request's ``seq_len`` equals its prompt length, otherwise
            ``None``.
        """
        num_reqs = self.input_batch.num_reqs
        assert num_reqs == len(self.input_batch.pooling_params), (
            "Either all or none of the requests in"
            " a batch must be pooling request")

        extracted_hidden_states = list(
            torch.split(hidden_states[:num_scheduled_tokens],
                        num_scheduled_tokens_np.tolist()))

        pooling_metadata = self.input_batch.pooling_metadata

        raw_pooler_output = self.model.pooler(
            hidden_states=extracted_hidden_states,
            pooling_metadata=pooling_metadata)

        # Bring all sequence lengths to the host in one transfer instead of
        # evaluating a tensor comparison per request inside the loop. If
        # `self.seq_lens` lives on the GPU (presumably; not visible here),
        # each of those per-request comparisons would force its own sync.
        seq_lens = self.seq_lens[:num_reqs].tolist()

        pooler_output: list[Optional[torch.Tensor]] = []
        for raw_output, seq_len, prompt_len in zip(
                raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
            # Only requests whose whole prompt has been processed report a
            # pooled value; partially processed prompts report None.
            if seq_len == prompt_len:
                pooler_output.append(raw_output.data.cpu())
            else:
                pooler_output.append(None)

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=[],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=pooler_output,
            finished_sending=finished_sending,
            finished_recving=finished_recving,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
        """Run one forward pass (plus sampling) for the scheduled batch.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Tensors received from the previous pipeline
                stage. They are ignored on the first PP rank.

        Returns:
            ``IntermediateTensors`` from a non-last pipeline stage when PP
            output is not broadcast. Otherwise a ``ModelRunnerOutput``. That
            output is ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing was
            scheduled and no KV transfer group exists.
        """
        self._update_states(scheduler_output)

        if not scheduler_output.total_num_scheduled_tokens:
            # No tokens to run. A KV connector may still need to make
            # progress without a forward pass.
            if has_kv_transfer_group():
                return self.kv_connector_no_forward(scheduler_output,
                                                    self.vllm_config)
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Prepare the decoder inputs.
        (attn_metadata, attention_cuda_graphs, logits_indices,
         spec_decode_metadata, num_scheduled_tokens_np,
         spec_decode_common_attn_metadata) = self._prepare_inputs(
             scheduler_output)

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Piecewise CUDA graphs: pad up to a captured batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode. With collective fusion for sequence parallelism,
            # pad to a multiple of the tensor-parallel size.
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            sp_enabled = (self.compilation_config.pass_config.
                          enable_sequence_parallelism)
            num_input_tokens = (round_up(num_scheduled_tokens, tp_size)
                                if sp_enabled and tp_size > 1 else
                                num_scheduled_tokens)

        # Pad so that all data-parallel ranks run the same number of tokens.
        dp_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
        num_input_tokens += dp_pad

        # _prepare_inputs may reorder the batch, so multimodal outputs are
        # gathered only after it, to keep them in batch order.
        mm_embeds: list[torch.Tensor] = []
        if self.is_multimodal_model:
            # Run the multimodal encoder, if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)

        pp_group = get_pp_group()
        if self.is_multimodal_model and pp_group.is_first_rank:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), multimodal models always receive embeddings rather
            # than token ids, even when the input is text.
            model_kwargs = self._init_model_kwargs_for_multimodal_model(
                scheduler_output=scheduler_output)
            embeds = self.model.get_input_embeddings(
                input_ids=self.input_ids[:num_scheduled_tokens],
                multimodal_embeddings=mm_embeds or None,
            )
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(embeds)
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
        else:
            # Text-only models take token ids. Embeddings would also work,
            # but then the embedding layer would fall outside the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
            model_kwargs = {}

        positions = (self.mrope_positions[:, :num_input_tokens]
                     if self.uses_mrope else self.positions[:num_input_tokens])

        if pp_group.is_first_rank:
            intermediate_tensors = None
        else:
            intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                num_input_tokens, intermediate_tensors, True)

        # Some attention backends only support CUDA graphs in pure decode.
        # If this batch cannot use them but full CUDA graphs were compiled,
        # skip CUDA graphs entirely.
        skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs

        # Run the model, using the persistent buffers for CUDA graphs.
        with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=num_input_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
                skip_cuda_graphs=skip_cuda_graphs,
        ):
            self.maybe_setup_kv_connector(scheduler_output)

            model_output = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **MultiModalKwargs.as_kwargs(
                    model_kwargs,
                    device=self.device,
                ),
            )

            self.maybe_wait_for_kv_save()
            finished_sending, finished_recving = (
                self.get_finished_kv_transfers(scheduler_output))

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = model_output
        else:
            hidden_states, aux_hidden_states = model_output, None

        # Broadcast PP output for external_launcher (torchrun) so that all
        # pipeline ranks stay in sync.
        # TODO: Support overlapping micro-batches
        # https://github.com/vllm-project/vllm/issues/18019
        broadcast_pp_output = (
            self.parallel_config.distributed_executor_backend
            == "external_launcher" and len(pp_group.ranks) > 0)

        if not pp_group.is_last_rank:
            if not broadcast_pp_output:
                # Mid-pipeline stage: hand the hidden states to the next
                # stage, attaching any finished KV transfers.
                if finished_sending or finished_recving:
                    hidden_states.finished_sending = finished_sending
                    hidden_states.finished_recving = finished_recving
                return hidden_states
            assert isinstance(hidden_states, IntermediateTensors)
            pp_group.send_tensor_dict(hidden_states.tensors,
                                      all_gather_group=get_tp_group())
            logits = None
        else:
            if self.input_batch.pooling_params:
                return self._pool(hidden_states, num_scheduled_tokens,
                                  num_scheduled_tokens_np, finished_sending,
                                  finished_recving)
            sample_hidden_states = hidden_states[logits_indices]
            logits = self.model.compute_logits(sample_hidden_states, None)

        if broadcast_pp_output:
            payload = ({
                "logits": logits.contiguous()
            } if logits is not None else {})
            payload = pp_group.broadcast_tensor_dict(
                payload, src=len(pp_group.ranks) - 1)
            assert payload is not None
            logits = payload["logits"]

        # Apply structured output bitmasks, if present.
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token (and logprobs, when requested).
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.sampler(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            assert logits is not None
            # Indexing with a tensor creates a new tensor with its own
            # storage. In-place updates to `bonus_logits` by the sampler, or
            # to `target_logits` by the rejection sampler, therefore never
            # touch `logits`.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            sampler_output.sampled_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                sampler_output.sampled_token_ids,  # bonus token ids
                sampling_metadata,
            )

        num_nans_in_logits = (self._get_nans_in_logits(logits)
                              if envs.VLLM_COMPUTE_NANS_IN_LOGITS else {})

        # TODO(woosuk): This loop visits requests one by one and can be slow.
        # Optimize.
        discard_sampled_tokens_req_indices = []
        for req_idx, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                continue
            # Partial prefill: ignore the sampled token and rewind the
            # generator as if it had not been sampled. This relies on
            # CUDA-specific torch-internal implementation details.
            generator = self.input_batch.generators.get(req_idx)
            if generator is not None:
                generator.set_offset(generator.get_offset() - 4)
            # Remember the request so its sampled tokens are cleared before
            # returning.
            discard_sampled_tokens_req_indices.append(req_idx)

        # NOTE: GPU -> CPU sync happens here.
        # Keep as much CPU work as possible before this point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = (None if logprobs_tensors is None else
                          logprobs_tensors.tolists())

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output,
        )

        # Collect the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        if sampled_token_ids.shape[-1] == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )

        # Drop the tokens sampled for partial prefills.
        for req_idx in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[req_idx].clear()

        # Cache the sampled tokens in the model runner so the scheduler does
        # not need to send them back.
        # NOTE(woosuk): As an exception, with PP the scheduler does send the
        # sampled tokens back, because the first-stage and last-stage workers
        # do not communicate directly.
        input_batch = self.input_batch
        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
            if not sampled_ids:
                continue

            start_idx = input_batch.num_tokens_no_spec[req_idx]
            end_idx = start_idx + len(sampled_ids)
            assert end_idx <= self.max_model_len, (
                "Sampled token IDs exceed the max model length. "
                f"Total number of tokens: {end_idx} > max_model_len: "
                f"{self.max_model_len}")

            input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
            input_batch.num_tokens_no_spec[req_idx] = end_idx
            input_batch.num_tokens[req_idx] = end_idx
            req_id = input_batch.req_ids[req_idx]
            self.requests[req_id].output_token_ids.extend(sampled_ids)

        if self.speculative_config:
            assert spec_decode_common_attn_metadata is not None
            spec_token_ids = self.propose_draft_token_ids(
                scheduler_output,
                valid_sampled_token_ids,
                sampling_metadata,
                hidden_states,
                sample_hidden_states,
                aux_hidden_states,
                spec_decode_metadata,
                spec_decode_common_attn_metadata,
            )
        else:
            # Speculative decoding is not enabled.
            spec_token_ids = None

        self.eplb_step()

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
            pooler_output=[],
            finished_sending=finished_sending,
            finished_recving=finished_recving,
            num_nans_in_logits=num_nans_in_logits,
        )

    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: Optional[torch.Tensor],
        spec_decode_metadata: Optional[SpecDecodeMetadata],
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[list[int]]:
        """Propose draft tokens for the next step with the configured drafter.

        Dispatches on ``self.speculative_config.method``: ``"ngram"``,
        ``"medusa"``, or an EAGLE variant (``use_eagle()``).

        Args:
            scheduler_output: Scheduler output for the current step.
            sampled_token_ids: Tokens accepted this step, one list per
                request in batch order. The list is empty when the request's
                sampled tokens were discarded.
            sampling_metadata: Sampling metadata of the current batch.
            hidden_states: Full hidden states from the target model.
            sample_hidden_states: Hidden states at the logits positions.
            aux_hidden_states: Auxiliary hidden states. Used by EAGLE when
                ``use_aux_hidden_state_outputs`` is set.
            spec_decode_metadata: Metadata for the draft tokens verified this
                step, or ``None`` if no draft tokens were scheduled.
            common_attn_metadata: Attention metadata passed to the EAGLE
                drafter.

        Returns:
            One list of draft token ids per request, in batch order.
        """
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(self.drafter, MedusaProposer)
            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
            else:
                # Pick, per request, the row of its last accepted token. Each
                # request occupies `num_draft + 1` consecutive rows. A
                # request with an empty token list (sampled tokens discarded)
                # is clamped to its own first row. Without the clamp,
                # `offset - 1` would index the previous request's row, or
                # wrap to the last row of the batch for the first request.
                indices = []
                offset = 0
                for num_draft, tokens in zip(
                        spec_decode_metadata.num_draft_tokens,
                        sampled_token_ids):
                    indices.append(offset + max(len(tokens), 1) - 1)
                    offset += num_draft + 1
                indices = torch.tensor(indices, device=self.device)
                hidden_states = sample_hidden_states[indices]

            spec_token_ids = self.drafter.propose(
                target_hidden_states=hidden_states,
                sampling_metadata=sampling_metadata,
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]
                else:
                    # Partial prefill (rare case).
                    # Get the next token id from the request state.
                    req_id = self.input_batch.req_ids[i]
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)

            if spec_decode_metadata is None:
                # Read from the persistent input_ids buffer: the input_ids
                # handed to the model may be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[:num_scheduled_tokens]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                num_rejected_tokens = [
                    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                    for i, n in enumerate(num_draft_tokens)
                ]
                num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
                                                       dtype=torch.int32)
                common_attn_metadata, token_indices = (
                    self.drafter.prepare_inputs(common_attn_metadata,
                                                num_rejected_tokens_cpu))

                target_token_ids = self.input_ids[token_indices]
                # TODO(woosuk): Support M-RoPE.
                target_positions = self.positions[token_indices]
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
                else:
                    target_hidden_states = hidden_states[token_indices]
            draft_token_ids = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
            )
            spec_token_ids = draft_token_ids.tolist()
        return spec_token_ids

    def propose_ngram_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """Ask the n-gram drafter for draft tokens for each request.

        A request gets an empty draft list when any of these holds:
        - it sampled no tokens this step;
        - it is listed in ``spec_decode_unsupported_reqs``;
        - its token count has reached ``max_model_len``;
        - the drafter returns ``None`` or an empty result.

        Args:
            sampled_token_ids: Tokens sampled this step, one list per request
                in the order of ``self.input_batch``.

        Returns:
            One list of draft token ids per request, in batch order.
        """
        # TODO(woosuk): Optimize.
        batch = self.input_batch
        unsupported = batch.spec_decode_unsupported_reqs
        proposals: list[list[int]] = []
        for idx, sampled in enumerate(sampled_token_ids):
            draft: list[int] = []
            # Requests that sampled nothing or use sampling parameters not
            # supported with speculative decoding are skipped.
            if len(sampled) > 0 and batch.req_ids[idx] not in unsupported:
                seq_len = batch.num_tokens_no_spec[idx]
                # Requests already at the max model length are skipped too.
                if seq_len < self.max_model_len:
                    proposed = self.drafter.propose(
                        batch.token_ids_cpu[idx, :seq_len])
                    if proposed is not None and len(proposed) > 0:
                        draft = proposed.tolist()
            proposals.append(draft)
        return proposals

1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
    def update_config(self, overrides: dict[str, Any]) -> None:
        """Replace whitelisted config attributes with overridden copies.

        Args:
            overrides: Maps a config attribute name (only ``load_config`` or
                ``model_config`` are accepted) to the field overrides passed
                to ``vllm.config.update_config``.

        Raises:
            AssertionError: if an unsupported config name is given.
        """
        permitted = {"load_config", "model_config"}
        for name, field_overrides in overrides.items():
            assert name in permitted, \
                f"Config `{name}` not supported. " \
                f"Allowed configs: {permitted}"
            updated = update_config(getattr(self, name), field_overrides)
            setattr(self, name, updated)

1916
1917
1918
1919
1920
    def load_model(self, eep_scale_up: bool = False) -> None:
        """Load the model (plus LoRA adapters and drafter) onto the device.

        Also records the memory consumed by loading and prepares the model's
        communication buffers. For MoE models with EPLB enabled it builds the
        EPLB state. When the compilation level is ``DYNAMO_AS_IS`` (and
        Dynamo is supported) it compiles the model with ``torch.compile``.

        Args:
            eep_scale_up: the model loading is for elastic EP scale up. Then
                the number of local physical experts is received via a
                broadcast from EP rank 0, and the previous expert load and
                placement via ``EplbState.recv_state()``; both feed into
                ``EplbState.build``.
        """
        logger.info("Starting to load model %s...", self.model_config.model)

        # Expert-placement state inherited from the previous EP group. These
        # stay None unless we are scaling up elastic EP.
        global_expert_load = None
        old_global_expert_indices = None
        rank_mapping = None
        if eep_scale_up:
            from vllm.distributed.parallel_state import get_ep_group
            ep_group = get_ep_group()
            experts_per_rank_buf = torch.empty(1,
                                               dtype=torch.int32,
                                               device="cpu")
            torch.distributed.broadcast(experts_per_rank_buf,
                                        group=ep_group.cpu_group,
                                        group_src=0)
            num_local_physical_experts = int(experts_per_rank_buf.item())
            new_ep_size = ep_group.world_size
            global_expert_load, old_global_expert_indices = (
                EplbState.recv_state())
            num_logical_experts = global_expert_load.shape[1]
            self.parallel_config.num_redundant_experts = (
                num_local_physical_experts * new_ep_size - num_logical_experts)
            num_old_physical_experts = old_global_expert_indices.shape[1]
            assert num_old_physical_experts % num_local_physical_experts == 0
            old_ep_size = (num_old_physical_experts //
                           num_local_physical_experts)
            # Identity mapping: old EP rank i stays rank i.
            rank_mapping = {rank: rank for rank in range(old_ep_size)}

        with DeviceMemoryProfiler() as mem_profiler:
            load_start = time.perf_counter()
            loader = get_model_loader(self.load_config)
            logger.info("Loading model from scratch...")
            self.model = loader.load_model(vllm_config=self.vllm_config,
                                           model_config=self.model_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)
            if hasattr(self, "drafter"):
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
                self.model.set_aux_hidden_state_layers(aux_layers)
            load_seconds = time.perf_counter() - load_start
        self.model_memory_usage = mem_profiler.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes, load_seconds)
        prepare_communication_buffer_for_model(self.model)

        if (is_mixture_of_experts(self.model)
                and self.parallel_config.enable_eplb):
            logger.info("EPLB is enabled for model %s.",
                        self.model_config.model)
            self.eplb_state = EplbState.build(
                self.model,
                self.device,
                self.parallel_config,
                global_expert_load,
                old_global_expert_indices,
                rank_mapping,
            )

        compilation_config = self.vllm_config.compilation_config
        if (compilation_config.level == CompilationLevel.DYNAMO_AS_IS
                and supports_dynamo()):
            backend = compilation_config.init_backend(self.vllm_config)
            compilation_counter.dynamo_as_is_count += 1
            self.model.compile(
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)

1999
2000
2001
2002
2003
2004
2005
    def reload_weights(self) -> None:
        """Load weights into the already-loaded model, in place.

        Raises:
            AssertionError: if no model has been loaded yet.
        """
        current_model = getattr(self, "model", None)
        assert current_model is not None, \
            "Cannot reload weights before model is loaded."
        loader = get_model_loader(self.load_config)
        logger.info("Reloading weights inplace...")
        loader.load_weights(current_model, model_config=self.model_config)

2006
2007
2008
2009
2010
2011
2012
    def save_tensorized_model(
        self,
        tensorizer_config: "TensorizerConfig",
    ) -> None:
        """Serialize the loaded model via ``TensorizerLoader.save_model``.

        Args:
            tensorizer_config: Tensorizer settings forwarded unchanged to
                ``TensorizerLoader.save_model``.
        """
        TensorizerLoader.save_model(self.model,
                                    model_config=self.model_config,
                                    tensorizer_config=tensorizer_config)

2016
2017
2018
2019
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for requests that asked for them.

        Prompt logprobs are produced chunk by chunk (chunked prefill). They
        are accumulated into a per-request CPU ``LogprobsTensors`` held in
        ``self.input_batch.in_progress_prompt_logprobs_cpu``. A request's
        tensors are returned only once its last prompt chunk has been
        handled. At that point the request is removed from both the
        in-progress dict and ``self.input_batch.num_prompt_logprobs``.

        Args:
            hidden_states: Hidden states of all tokens scheduled this step.
                A request's rows start at ``self.query_start_loc_np`` for its
                batch index.
            scheduler_output: Supplies the number of tokens scheduled for
                each request this step.

        Returns:
            Mapping from request id to its completed prompt logprobs. Empty
            if no request needs prompt logprobs or none finished its prompt.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this chunk's targets are moved
            # to the device; copying the full prompt on every chunk is
            # wasted transfer for long, chunked prompts.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok +
                                         num_logits]).to(self.device,
                                                         non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            self._sync_device()

        return prompt_logprobs_dict

2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
    def _get_nans_in_logits(
        self,
        logits: Optional[torch.Tensor],
    ) -> dict[str, int]:
        try:
            if logits is None:
                return {req_id: 0 for req_id in self.input_batch.req_ids}

            num_nans_in_logits = {}
            num_nans_for_index = logits.isnan().sum(dim=-1).cpu().numpy()
            for req_id in self.input_batch.req_ids:
                req_index = self.input_batch.req_id_to_index[req_id]
                num_nans_in_logits[req_id] = (
                    int(num_nans_for_index[req_index])
                    if num_nans_for_index is not None
                    and req_index < logits.shape[0] else 0)
            return num_nans_in_logits
        except IndexError:
            return {}

2131
2132
2133
2134
2135
2136
    @contextmanager
    def maybe_randomize_inputs(self, input_ids: Optional[torch.Tensor]):
        """Randomize ``input_ids`` if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.

        This helps balance expert selection
          - during profile_run
          - during DP rank dummy run

        Randomization happens only when the env flag is set, the data
        parallel size is greater than 1, and ``input_ids`` is not ``None``.
        ``_dummy_run`` passes ``None`` for multimodal models, which feed
        ``inputs_embeds`` instead. When active, the buffer holds random
        token ids in ``[0, vocab_size)`` for the duration of the ``with``
        block. It is reset to zeros on exit, even if the block raises.
        """
        dp_size = self.vllm_config.parallel_config.data_parallel_size
        randomize_inputs = (envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS
                            and dp_size > 1 and input_ids is not None)
        if not randomize_inputs:
            yield
            return

        logger.debug("Randomizing dummy data for DP Rank")
        # Generate exactly as many ids as the slice needs (high is exclusive).
        rand_ids = torch.randint_like(input_ids,
                                      low=0,
                                      high=self.model_config.get_vocab_size(),
                                      dtype=input_ids.dtype)
        input_ids.copy_(rand_ids, non_blocking=True)
        try:
            yield
        finally:
            input_ids.fill_(0)

2160
2161
2162
2163
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        capture_attn_cudagraph: bool = False,
        skip_eplb: bool = False,
        is_profile: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one forward pass over a synthetic batch.

        ``num_tokens`` is first padded for DP. The tokens are then spread over
        ``min(num_tokens, max_num_seqs)`` dummy requests, and the last request
        takes the remainder.

        Args:
            num_tokens: Number of tokens before DP padding.
            capture_attn_cudagraph: Build per-layer attention metadata for
                CUDA graph capture, with sequence lengths set to
                ``max_model_len``.
            skip_eplb: Skip the EPLB step after the forward pass.
            is_profile: Forwarded to ``eplb_step``.

        Returns:
            ``(hidden_states, hidden_states of the last token of each dummy
            request)``.
        """
        # Padding for DP
        dp_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
        num_tokens += dp_pad

        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)
        base_len, leftover = divmod(num_tokens, num_reqs)
        per_req_tokens = [base_len] * num_reqs
        per_req_tokens[-1] += leftover
        assert sum(per_req_tokens) == num_tokens
        assert len(per_req_tokens) == num_reqs
        num_scheduled_tokens = np.array(per_req_tokens, dtype=np.int32)

        attn_metadata: Optional[dict[str, Any]] = None
        if capture_attn_cudagraph:
            attn_metadata = {}

            # Make sure max_model_len is used at the graph capture time.
            self.seq_lens_np[:num_reqs] = self.max_model_len
            self.seq_lens_np[num_reqs:] = 0
            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                           non_blocking=True)

            for group_id, group_spec in enumerate(
                    self.kv_cache_config.kv_cache_groups):
                group_block_table = self.input_batch.block_table[group_id]
                common_attn_metadata = CommonAttentionMetadata(
                    query_start_loc=self.query_start_loc[:num_reqs + 1],
                    query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
                                                                 1],
                    seq_lens=self.seq_lens[:num_reqs],
                    seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
                    num_computed_tokens_cpu=self.input_batch.
                    num_computed_tokens_cpu_tensor[:num_reqs],
                    num_reqs=num_reqs,
                    num_actual_tokens=num_tokens,
                    max_query_len=num_tokens,
                    block_table_tensor=group_block_table.get_device_tensor()
                    [:num_reqs],
                    slot_mapping=group_block_table.slot_mapping[:num_tokens],
                    causal=True)

                group_attn_metadata = self.attn_metadata_builders[
                    group_id].build_for_cudagraph_capture(
                        common_attn_metadata)
                for layer_name in group_spec.layer_names:
                    attn_metadata[layer_name] = group_attn_metadata

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            if self.is_multimodal_model:
                model_kwargs = self._init_model_kwargs_for_multimodal_model(
                    num_reqs=num_reqs)
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                model_kwargs = {}
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None

            positions = (self.mrope_positions[:, :num_tokens]
                         if self.uses_mrope else self.positions[:num_tokens])

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                    num_tokens, None, False)

            with self.maybe_randomize_inputs(input_ids), set_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    num_tokens_across_dp=num_tokens_across_dp):
                outputs = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                    **MultiModalKwargs.as_kwargs(
                        model_kwargs,
                        device=self.device,
                    ),
                )

            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs

            if self.speculative_config and self.speculative_config.use_eagle():
                assert isinstance(self.drafter, EagleProposer)
                self.drafter.dummy_run(num_tokens)

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
        # requests to process.
        # However, in DP settings, there may be cases when some DP ranks do
        # not have any requests to process, so they're executing dummy batches.
        # In such cases, we still have to trigger EPLB to make sure
        # ranks execute the rearrangement in synchronization.
        if not skip_eplb:
            self.eplb_step(is_dummy=True, is_profile=is_profile)

        # Row of the final token of each dummy request.
        last_token_rows = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states, hidden_states[last_token_rows]
2290
2291
2292
2293
2294
2295

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler, and the rejection sampler if speculative
        decoding is configured, on a dummy batch.

        Each row of ``hidden_states`` becomes one dummy request. If the
        sampler raises a ``RuntimeError`` whose message contains
        ``'out of memory'``, a new ``RuntimeError`` with tuning advice is
        raised from it.

        Returns:
            The sampler output for the dummy batch.
        """
        # The dummy hidden states may contain special values,
        # like `inf` or `nan`.
        # To avoid breaking the sampler, we use a random tensor here instead.
        logits = self.model.compute_logits(torch.rand_like(hidden_states),
                                           None)
        num_reqs = logits.size(0)

        def per_request(value):
            # One copy of `value` per dummy request, on the runner's device.
            return torch.full((num_reqs, ), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=per_request(0.5),
            all_greedy=False,
            all_random=False,
            top_p=per_request(0.9),
            top_k=per_request(logits.size(1) - 1),
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=per_request(0.1),
            presence_penalties=per_request(0.1),
            repetition_penalties=per_request(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
            logitsprocs=LogitsProcessorManager(),
        )
        try:
            sampler_output = self.sampler(logits=logits,
                                          sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise e
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine.") from e

        if self.speculative_config:
            # One dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)
            num_draft_tokens = sum(map(len, draft_token_ids))
            target_logits = torch.randn(num_draft_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output
2365

2366
    def _dummy_pooler_run_task(
        self,
        hidden_states: torch.Tensor,
        task: PoolingTask,
    ) -> PoolerOutput:
        """Run the pooler for ``task`` on a dummy batch.

        The rows of ``hidden_states`` are split across
        ``min(num_tokens, max_num_seqs)`` dummy requests, and the last request
        takes the remainder.

        Args:
            hidden_states: Hidden states of all dummy tokens.
            task: Pooling task to warm up.

        Returns:
            The pooler output for the dummy batch.

        Raises:
            RuntimeError: with a hint to lower ``max_num_seqs`` or
                ``gpu_memory_utilization`` when the pooler reports out of
                memory. Other runtime errors are re-raised unchanged.
        """
        num_tokens = hidden_states.shape[0]
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        hidden_states_list = list(
            torch.split(hidden_states, num_scheduled_tokens_list))

        dummy_prompt_lens = torch.tensor(
            num_scheduled_tokens_list,
            device=self.device,
        )
        # The padded token-id matrix must be as wide as the longest dummy
        # prompt. The last request carries the remainder tokens, so a width
        # of num_tokens // num_reqs would be narrower than its prompt_len.
        max_prompt_len = max(num_scheduled_tokens_list)
        dummy_token_ids = torch.zeros((num_reqs, max_prompt_len),
                                      dtype=torch.int32,
                                      device=self.device)

        model = cast(VllmModelForPooling, self.model)
        dummy_pooling_params = PoolingParams(task=task)
        to_update = model.pooler.get_pooling_updates(task)
        to_update.apply(dummy_pooling_params)

        dummy_metadata = PoolingMetadata(
            prompt_lens=dummy_prompt_lens,
            prompt_token_ids=dummy_token_ids,
            pooling_params=[dummy_pooling_params] * num_reqs,
        )

        try:
            return model.pooler(hidden_states=hidden_states_list,
                                pooling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up pooler "
                    f"({task=}) with {num_reqs} dummy requests. Please try "
                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine.") from e
            raise
2415
2416
2417
2418
2419
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430

    @torch.inference_mode()
    def _dummy_pooler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> PoolerOutput:
        """Warm up the pooler for every supported pooling task.

        Each supported task runs once on the full dummy batch so that an OOM
        in any of them surfaces now. The task with the largest output (by
        ``get_data_nbytes()``) is then run again and its output returned.

        Raises:
            ValueError: if the model reports no supported pooling tasks.
        """
        supported_tasks = list(self.get_supported_pooling_tasks())
        if not supported_tasks:
            raise ValueError(
                "Cannot run dummy pooler: the model reports no supported "
                "pooling tasks.")

        # Find the task that has the largest output for subsequent steps
        output_size: dict[PoolingTask, float] = {}
        for task in supported_tasks:
            # Run a full batch with each task to ensure none of them OOMs
            output = self._dummy_pooler_run_task(hidden_states, task)
            output_size[task] = output.get_data_nbytes()
            del output  # Allow GC

        # On ties, max() keeps the first task in iteration order.
        max_task = max(output_size, key=output_size.__getitem__)
        return self._dummy_pooler_run_task(hidden_states, max_task)
2431

2432
    def profile_run(self) -> None:
        """Profile with a maximum-size dummy batch.

        For multimodal models with a non-zero encoder budget, the multimodal
        encoder runs first. It uses the modality with the largest per-item
        token count and as many items as the encoder and decoder budgets
        allow (at least one). Its outputs stay in the encoder cache under
        the key ``"tmp"`` until the end of profiling.

        A ``max_num_tokens`` dummy forward pass follows. On the last PP rank
        it is followed by a dummy pooler run (pooling models) or a dummy
        sampler run. Everything is then released and ``gc.collect()`` runs.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            registry = self.mm_registry
            tokens_by_modality = (
                registry.get_max_tokens_per_item_by_nonzero_modality(
                    self.model_config))
            modality, tokens_per_item = max(tokens_by_modality.items(),
                                            key=lambda kv: kv[1])

            # Check how many items of this modality can be supported by
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)
            encoder_item_limit = encoder_budget // tokens_per_item

            # Check how many items of this modality can be supported by
            # the decoder budget.
            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            items_per_prompt = registry.get_mm_limits_per_prompt(
                self.model_config)[modality]
            decoder_item_limit = self.max_num_reqs * items_per_prompt

            num_items = max(1, min(encoder_item_limit, decoder_item_limit))

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, num_items, modality)

            # Create dummy batch of multimodal inputs.
            single_item_kwargs = registry.get_decoder_dummy_data(
                model_config=self.model_config,
                seq_len=tokens_per_item,
                mm_counts={modality: 1},
            ).multi_modal_data
            batched_kwargs = MultiModalKwargs.batch(
                [single_item_kwargs] * num_items,
                pin_memory=self.pin_memory)
            encoder_inputs = MultiModalKwargs.as_kwargs(
                batched_kwargs,
                device=self.device,
            )

            # Run multimodal encoder.
            encoder_outputs = self.model.get_multimodal_embeddings(
                **encoder_inputs)
            sanity_check_mm_encoder_outputs(
                encoder_outputs,
                expected_num_items=num_items,
            )

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(encoder_outputs))

        # Add `is_profile` here to pre-allocate communication buffers
        hidden_states, last_hidden_states = self._dummy_run(
            self.max_num_tokens, is_profile=True)

        output = None
        if get_pp_group().is_last_rank:
            if self.is_pooling_model:
                output = self._dummy_pooler_run(hidden_states)
            else:
                output = self._dummy_sampler_run(last_hidden_states)

        self._sync_device()
        del hidden_states, output
        self.encoder_cache.clear()
        gc.collect()
2518
2519

    def capture_model(self) -> None:
        """Capture CUDA graphs for every configured batch size.

        When ``self.use_cuda_graph`` is falsy this only logs a warning and
        returns. Otherwise it increments the GPU-runner capture counter and,
        for each size in ``self.cudagraph_batch_sizes`` (largest first),
        performs ``cudagraph_num_of_warmups`` warmup dummy runs followed by
        one more dummy run. Finally it logs the elapsed wall time and the
        drop in free GPU memory observed across the whole capture.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "set -O %s and ensure `use_cudagraph` was not manually set to "
                "False", CompilationLevel.PIECEWISE)
            return

        compilation_counter.num_gpu_runner_capture_triggers += 1

        t_begin = time.perf_counter()
        free_mem_before = torch.cuda.mem_get_info()[0]

        @contextmanager
        def freeze_gc():
            # Collect once up front. Unless VLLM_ENABLE_CUDAGRAPH_GC is set,
            # freeze the surviving objects so that collections triggered
            # during capture skip them; always unfreeze on exit.
            gc.collect()
            freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
            if freeze:
                gc.freeze()
            try:
                yield
            finally:
                if freeze:
                    gc.unfreeze()

        # Larger shapes are captured first so that smaller shapes can reuse
        # the memory pool allocated for the larger ones.
        with freeze_gc(), graph_capture(device=self.device):
            run_kwargs = dict(capture_attn_cudagraph=self.full_cuda_graph,
                              skip_eplb=True)
            sizes = reversed(self.cudagraph_batch_sizes)
            # Only the global first rank shows a progress bar.
            if is_global_first_rank():
                sizes = tqdm(list(sizes),
                             disable=not self.load_config.use_tqdm_on_load,
                             desc="Capturing CUDA graph shapes")
            for num_tokens in sizes:
                # skip_eplb=True keeps these dummy runs out of EPLB metrics.
                num_warmups = self.compilation_config.cudagraph_num_of_warmups
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens, **run_kwargs)
                self._dummy_run(num_tokens, **run_kwargs)

        elapsed = time.perf_counter() - t_begin
        graph_bytes = free_mem_before - torch.cuda.mem_get_info()[0]
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed, graph_bytes / (1 << 30))
2577

2578
    def _initialize_single_attn_backend(
        self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
    ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
        """Select the backend for one KV cache spec and build its metadata
        builder.

        Args:
            kv_cache_spec: Spec of the KV cache group. Must be an
                ``AttentionSpec`` or a ``MambaSpec``.
            layer_names: Names of the layers the builder serves.

        Returns:
            The backend class and an instance of its metadata builder,
            constructed with ``(kv_cache_spec, layer_names, vllm_config,
            device)``.

        Raises:
            NotImplementedError: ``get_attn_backend`` returned None.
            ValueError: The spec type is unknown, or full CUDA graphs are
                enabled but the builder does not support them.
        """
        if isinstance(kv_cache_spec, AttentionSpec):
            backend = get_attn_backend(
                kv_cache_spec.head_size,
                self.dtype,
                kv_cache_spec.dtype,
                kv_cache_spec.block_size,
                self.model_config.is_attention_free,
                use_mla=kv_cache_spec.use_mla,
            )
            if backend is None:
                # Log the exact lookup parameters before failing.
                error_msg = (f"Error with get_attn_backend: "
                             f"{kv_cache_spec.head_size=}, "
                             f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                             f"{kv_cache_spec.block_size=}, "
                             f"{self.model_config.is_attention_free=}, "
                             f"{kv_cache_spec.use_mla=}")
                logger.error(error_msg)
                raise NotImplementedError(
                    "Non-Attention backend is not supported by V1 "
                    "GPUModelRunner.")
        elif isinstance(kv_cache_spec, MambaSpec):
            backend = get_mamba_attn_backend(kv_cache_spec.mamba_type)
        else:
            raise ValueError(
                f"Unknown KV cache spec type: {type(kv_cache_spec)}")

        builder_cls = backend.get_builder_cls()
        builder = builder_cls(kv_cache_spec, layer_names, self.vllm_config,
                              self.device)

        full_cg_unsupported = (self.full_cuda_graph
                               and not builder.full_cudagraph_supported)
        if full_cg_unsupported:
            raise ValueError(
                f"Full CUDAGraph not supported for "
                f"{backend.__name__}. Turn off CompilationConfig."
                f"full_cuda_graph or use a different attention backend.")
        return backend, builder

2622
2623
2624
2625
2626
2627
2628
2629
2630
2631
    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.

        One backend/builder pair is appended per KV cache group, in group
        order. If that produces no backends (the config has no KV cache
        groups), the model's attention layers are inspected. When all of
        them are encoder-only, a single backend/builder pair covering every
        layer is created and ``self.is_encoder_only_model`` is set to True.

        Args:
            kv_cache_config: The KV cache configuration.

        Raises:
            AssertionError: Backends were already initialized, or an
                encoder-only layer uses sliding-window attention.
            ValueError: With no KV cache groups, some attention layer is not
                encoder-only.
        """
        assert len(self.attn_backends) == 0 and len(
            self.attn_metadata_builders
        ) == 0, "Attention backends are already initialized"
        for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
            attn_backend_i, attn_metadata_builder_i = (
                self._initialize_single_attn_backend(
                    kv_cache_group_spec.kv_cache_spec,
                    kv_cache_group_spec.layer_names))
            self.attn_backends.append(attn_backend_i)
            self.attn_metadata_builders.append(attn_metadata_builder_i)

        if len(self.attn_backends) > 0:
            return

        # No KV cache groups: check whether the model is encoder-only.
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        attn_specs = list[AttentionSpec]()
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for attn_module in attn_layers.values():

            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
                # Parenthesized so the full text becomes the assertion
                # message. Without the parens, a line break would turn the
                # second literal into a separate no-op statement.
                assert attn_module.sliding_window is None, (
                    "Sliding window attention is not supported for "
                    "encoder-only models")

                attn_specs.append(
                    FullAttentionSpec(block_size=block_size,
                                      num_kv_heads=attn_module.num_kv_heads,
                                      head_size=attn_module.head_size,
                                      dtype=self.kv_cache_dtype,
                                      use_mla=use_mla))
            else:
                raise ValueError("Expected only encoder-only layers")

        if len(attn_specs) > 0:
            assert len(attn_specs) == len(attn_layers), \
                "All or none of the layers are expected to be encoder-only"

            # Pass a real list so that it matches the
            # `layer_names: list[str]` parameter (not a dict_keys view).
            attn_backend, attn_metadata_builder = (
                self._initialize_single_attn_backend(attn_specs[0],
                                                     list(attn_layers)))
            self.attn_backends.append(attn_backend)
            self.attn_metadata_builders.append(attn_metadata_builder)
            self.is_encoder_only_model = True

2673
2674
2675
2676
2677
2678
2679
2680
2681
2682
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
2693
2694
2695
2696
2697
2698
2699
    def may_reinitialize_input_batch(self,
                                     kv_cache_config: KVCacheConfig) -> None:
        """
        Rebuild ``self.input_batch`` when the KV cache groups' block sizes
        are anything other than ``[self.cache_config.block_size]``. This
        typically happens when there are several KV cache groups.

        Args:
            kv_cache_config: The KV cache configuration.
        """
        group_block_sizes = [
            group.kv_cache_spec.block_size
            for group in kv_cache_config.kv_cache_groups
        ]
        if group_block_sizes == [self.cache_config.block_size]:
            # The existing batch already matches; nothing to rebuild.
            return

        assert self.cache_config.cpu_offload_gb == 0, (
            "Cannot re-initialize the input batch when CPU weight "
            "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 "  # noqa: E501
            "for more details.")
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=group_block_sizes,
            is_spec_decode=bool(self.vllm_config.speculative_config),
        )

2703
2704
    def _allocate_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Allocate the raw KV cache buffers.

        For each entry of ``kv_cache_config.kv_cache_tensors``, one
        zero-filled ``torch.int8`` tensor of ``size`` elements is allocated
        on ``self.device``. It is mapped to every layer listed in that
        entry's ``shared_by``, so those layers share the same buffer. The
        buffers must be reshaped before the models can use them.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        raw_by_layer: dict[str, torch.Tensor] = {}
        for tensor_cfg in kv_cache_config.kv_cache_tensors:
            buffer = torch.zeros(tensor_cfg.size,
                                 dtype=torch.int8,
                                 device=self.device)
            raw_by_layer.update(dict.fromkeys(tensor_cfg.shared_by, buffer))

        # Every layer of every KV cache group must have received a buffer,
        # and no buffer may belong to a layer outside the groups.
        expected_layers = {
            name
            for group in kv_cache_config.kv_cache_groups
            for name in group.layer_names
        }
        assert expected_layers == set(raw_by_layer), \
            "Some layers are not correctly initialized"
        return raw_by_layer

    def _reshape_kv_cache_tensors(
        self,
        kv_cache_config: KVCacheConfig,
        kv_cache_raw_tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Reshape the KV cache tensors to the desired shape and dtype.

        Attention layers get a view of their raw buffer in the backend's KV
        cache shape. The buffer is laid out in the backend's stride order
        (when the backend defines one) and permuted back to the logical
        shape. Mamba layers get a list with one strided view per shape in
        ``kv_cache_spec.shapes``. Consecutive blocks are
        ``page_size_bytes // dtype_size`` elements apart, and within a block
        the states follow one another contiguously.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer, with
            correct size but uninitialized shape.
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache. For Mamba layers the
            value is a list of state tensors.

        Raises:
            AssertionError: A raw buffer's element count is not a multiple of
                the spec's ``page_size_bytes``.
            NotImplementedError: The KV cache spec type is not supported.
        """

        def contiguous_strides(shape) -> tuple[int, ...]:
            # Strides torch assigns to a freshly allocated contiguous tensor
            # of `shape` (size-0 dims count as 1), computed without
            # allocating a throwaway tensor of that size.
            strides: list[int] = []
            running = 1
            for dim in reversed(shape):
                strides.append(running)
                running *= max(dim, 1)
            return tuple(reversed(strides))

        kv_caches: dict[str, torch.Tensor] = {}
        has_attn, has_mamba = False, False
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            for layer_name in kv_cache_group_spec.layer_names:
                raw_tensor = kv_cache_raw_tensors[layer_name]
                page_size = kv_cache_spec.page_size_bytes
                assert raw_tensor.numel() % page_size == 0, (
                    f"KV cache buffer of layer {layer_name} has "
                    f"{raw_tensor.numel()} elements, which is not a multiple "
                    f"of page_size_bytes={page_size}")
                num_blocks = raw_tensor.numel() // page_size
                if isinstance(kv_cache_spec, AttentionSpec):
                    has_attn = True
                    attn_backend = self.attn_backends[i]
                    kv_cache_shape = attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    try:
                        kv_cache_stride_order = (
                            attn_backend.get_kv_cache_stride_order())
                        assert len(kv_cache_stride_order) == len(
                            kv_cache_shape)
                    except (AttributeError, NotImplementedError):
                        kv_cache_stride_order = tuple(
                            range(len(kv_cache_shape)))
                    # The allocation respects the backend-defined stride order
                    # so the semantics stay consistent for each backend. We
                    # first take the generic kv cache shape and then permute
                    # it according to the stride order, which can produce a
                    # non-contiguous tensor.
                    physical_shape = tuple(kv_cache_shape[d]
                                           for d in kv_cache_stride_order)
                    # Inverse permutation restores the logical KV shape view.
                    inv_order = [
                        kv_cache_stride_order.index(d)
                        for d in range(len(kv_cache_stride_order))
                    ]
                    kv_caches[layer_name] = raw_tensor.view(dtype).view(
                        physical_shape).permute(*inv_order)
                elif isinstance(kv_cache_spec, MambaSpec):
                    has_mamba = True
                    dtype = kv_cache_spec.dtype
                    num_element_per_page = (kv_cache_spec.page_size_bytes //
                                            get_dtype_size(dtype))
                    typed_raw = raw_tensor.view(dtype)
                    state_tensors = []
                    storage_offset = 0
                    for shape in kv_cache_spec.shapes:
                        target_shape = (num_blocks, *shape)
                        stride = contiguous_strides(target_shape)
                        # Blocks are one page apart; inside a block the
                        # state is contiguous.
                        target_stride = (num_element_per_page, *stride[1:])
                        tensor = torch.as_strided(
                            typed_raw,
                            size=target_shape,
                            stride=target_stride,
                            storage_offset=storage_offset,
                        )
                        state_tensors.append(tensor)
                        # The next state starts right after this one within
                        # the page (stride[0] == elements of one state).
                        storage_offset += stride[0]

                    kv_caches[layer_name] = state_tensors
                else:
                    raise NotImplementedError(
                        f"Unsupported KV cache spec type: "
                        f"{type(kv_cache_spec)}")

        if has_attn and has_mamba:
            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
                                                       kv_cache_raw_tensors)

        return kv_caches

2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
    def _verify_hybrid_attention_mamba_layout(
            self, kv_cache_config: KVCacheConfig,
            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
        """
        For models with both attention and mamba KV cache groups, check
        that every attention layer's backend reports a KV cache shape
        starting with ``(num_blocks, 2)``.

        Args:
            kv_cache_config: The KV cache config
            kv_cache_raw_tensors: The KV cache buffer of each layer.

        Raises:
            ValueError: An attention backend reports a different leading
                layout.
        """
        for group_idx, group in enumerate(kv_cache_config.kv_cache_groups):
            spec = group.kv_cache_spec
            for name in group.layer_names:
                n_blocks = (kv_cache_raw_tensors[name].numel() //
                            spec.page_size_bytes)
                if not isinstance(spec, AttentionSpec):
                    continue
                shape = self.attn_backends[group_idx].get_kv_cache_shape(
                    n_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)
                if not (shape[0] == n_blocks and shape[1] == 2):
                    raise ValueError(
                        "Hybrid models in V1 require an attention "
                        "backend with kv_cache_shape="
                        "(num_blocks, 2, ...). Please try setting "
                        "VLLM_ATTENTION_BACKEND=FLASHINFER")

2847
2848
2849
2850
2851
2852
2853
2854
    def initialize_kv_cache_tensors(
            self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
        """
        Allocate, reshape and bind the KV cache buffers.

        The raw buffers are allocated and then reshaped per layer. When
        ``self.shared_kv_cache_layers`` is non-empty, layers that reuse
        another layer's cache are wired up. The result is then bound into
        the static forward context and ``self.kv_caches``.

        Args:
            kv_cache_config: The KV cache config
        Returns:
            Dict[str, torch.Tensor]: A map between layer names to their
            corresponding memory buffer for KV cache.
        """
        raw_buffers = self._allocate_kv_cache_tensors(kv_cache_config)
        kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                   raw_buffers)

        if self.shared_kv_cache_layers:
            initialize_kv_cache_for_kv_sharing(
                self.shared_kv_cache_layers,
                kv_cache_config.kv_cache_groups,
                kv_caches,
            )
            # Walk attention layers from last to first and collect the
            # trailing run of layers that reuse another layer's KV cache
            # (YOCO-like setups such as Gemma3n). Stop at the first layer
            # that owns its cache.
            attn_layers = get_layers_from_vllm_config(self.vllm_config,
                                                      Attention)
            for layer_name in reversed(attn_layers):
                if layer_name not in self.shared_kv_cache_layers:
                    break
                self.kv_sharing_fast_prefill_eligible_layers.add(layer_name)

        bind_kv_cache(kv_caches,
                      self.compilation_config.static_forward_context,
                      self.kv_caches)
        return kv_caches

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Set up everything that depends on `kv_cache_config`.

        This stores the config and rebuilds the input batch if needed. It
        creates the attention backends and allocates the KV cache tensors.
        When EAGLE speculative decoding is in use, it validates the draft
        model's layers against the config. Finally, it registers the caches
        with the KV transfer group if one exists.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer
        """
        self.kv_cache_config = kv_cache_config
        self.may_reinitialize_input_batch(kv_cache_config)
        self.initialize_attn_backend(kv_cache_config)
        kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)

        spec_cfg = self.speculative_config
        if spec_cfg and spec_cfg.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # All draft-model layers must belong to one KV cache group.
            self.drafter.validate_same_kv_cache_group(kv_cache_config)

        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)

2909
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module (and each Mamba layer) returned by
        ``get_layers_from_vllm_config``.

        Side effect: every attention layer with a
        ``kv_sharing_target_layer_name`` is recorded in
        ``self.shared_kv_cache_layers`` (layer -> target) and gets no spec.
        Mamba layers get a ``MambaSpec`` whose block size is the model's
        ``max_model_len``.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: Encoder-decoder attention, or Mamba layers
                combined with speculative decoding or prefix caching.
            ValueError: Unknown attention type.
        """

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
                # the target layer. We skip creating a KVCacheSpec for it, so
                # that KV cache management logic will act as this layer does
                # not exist, and doesn't allocate KV cache for the layer. This
                # enables the memory saving of cross-layer kv sharing, allowing
                # a given amount of memory to accommodate longer context lengths
                # or enable more requests to be processed simultaneously.
                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            if attn_module.attn_type == AttentionType.DECODER:
                use_local_attention = (self.attention_chunk_size is not None
                                       and attn_module.use_irope)
                if attn_module.sliding_window is not None:
                    # A single message string; a stray comma here would make
                    # the assertion message a tuple.
                    assert not use_local_attention, (
                        "attention module can not be with "
                        "both local attention and sliding window")
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=use_mla)
                elif use_local_attention:
                    kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        attention_chunk_size=self.attention_chunk_size,
                        use_mla=use_mla)
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # Encoder and encoder-only attention do not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
        if len(mamba_layers) > 0:
            if self.vllm_config.speculative_config is not None:
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len

            page_size_padded = (
                self.vllm_config.cache_config.mamba_page_size_padded)

            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtype=self.kv_cache_dtype,
                    block_size=max_model_len,
                    page_size_padded=page_size_padded,
                    mamba_type=mamba_module.mamba_type)

        return kv_cache_spec
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
3019
3020
3021
3022
3023
3024
3025
3026
3027
3028
3029
3030
3031
3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
3043
3044
3045
3046
3047
3048

    def _build_encoder_only_attn_metadata(
            self, scheduler_output: "SchedulerOutput") -> \
                tuple[CommonAttentionMetadata, Any]:
        """Prepare encoder attention metadata for encoder-only models.

        Builds a non-causal ``CommonAttentionMetadata`` from the runner's
        query-start-location and sequence-length buffers, using all-zero
        placeholder block table and slot mapping tensors. It then hands the
        metadata to the first attention metadata builder.

        Args:
            scheduler_output: Scheduler output

        Returns:
            tuple[CommonAttentionMetadata, Any]: The common metadata and the
            backend-specific metadata returned by
            ``self.attn_metadata_builders[0].build``.

        Raises:
            ValueError: If the input batch has no requests (``max`` of an
                empty sequence).
        """
        num_reqs = self.input_batch.num_reqs
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        max_num_scheduled_tokens = max(tokens)

        # Use the first attention metadata builder
        # to create encoder attention metadata
        builder = self.attn_metadata_builders[0]

        # Zero-filled placeholders. Presumably never read because
        # encoder-only models keep no KV cache; confirm against backends.
        dummy_block_table = torch.zeros((num_reqs, 1),
                                        dtype=torch.int32,
                                        device=self.device)
        dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
                                         dtype=torch.int32,
                                         device=self.device)

        common_metadata = CommonAttentionMetadata(
            query_start_loc=self.query_start_loc[:num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
            seq_lens=self.seq_lens[:num_reqs],
            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
            num_computed_tokens_cpu=self.input_batch.
            num_computed_tokens_cpu_tensor[:num_reqs],
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            block_table_tensor=dummy_block_table,
            slot_mapping=dummy_slot_mapping,
            causal=False,
        )

        return common_metadata, builder.build(
            common_prefix_len=0,  # No cascade for encoder
            common_attn_metadata=common_metadata,
        )