gpu_model_runner.py 74.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
import weakref
6
from typing import TYPE_CHECKING, Optional, Union
7
8
9
10
11
12

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

13
from vllm.attention import AttentionType, get_attn_backend
14
from vllm.attention.layer import Attention
15
from vllm.config import CompilationLevel, VllmConfig
16
from vllm.distributed.parallel_state import get_pp_group, graph_capture
17
18
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import FusedMoE
20
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
21
from vllm.model_executor.model_loader import get_model
22
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
23
from vllm.multimodal.utils import group_mm_inputs_by_modality
24
from vllm.sampling_params import SamplingType
25
from vllm.sequence import IntermediateTensors
26
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
27
28
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
29
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
30
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
31
32
33
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
34
35
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
36
from vllm.v1.sample.metadata import SamplingMetadata
37
38
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
39
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
40
from vllm.v1.spec_decode.utils import is_spec_decode_supported
41
from vllm.v1.utils import bind_kv_cache
42
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
43
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
44

45
46
from .utils import sanity_check_mm_encoder_outputs

47
if TYPE_CHECKING:
48
49
    import xgrammar as xgr

50
    from vllm.v1.core.sched.output import SchedulerOutput
51
52
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
53
54
55
56

logger = init_logger(__name__)


57
class GPUModelRunner(LoRAModelRunnerMixin):
58
59
60

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache configuration and allocate the runner's persistent buffers.

        Args:
            vllm_config: Engine-wide configuration. Its sub-configs (model,
                cache, LoRA, load, parallel, scheduler, speculative, prompt
                adapter, observability) are stored as attributes.
            device: Device on which the persistent GPU-side buffers
                (``input_ids``, ``positions``, ``inputs_embeds``, ...) are
                allocated.

        Raises:
            NotImplementedError: If ``get_attn_backend`` returns ``None`` for
                this model's head size / dtype / KV-cache dtype / block size.
            AssertionError: If speculative decoding is configured with a
                method other than ``"ngram"``.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            # The KV cache uses the same dtype as the model weights.
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        # The builder receives a weak proxy rather than `self`, so it does
        # not hold a strong reference back to the runner (presumably to
        # avoid a reference cycle).
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))
        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            assert self.speculative_config.method == "ngram", \
                    "Currently, only ngram spec decode is supported in V1."
            if get_pp_group().is_last_rank:
                self.drafter = NgramProposer()
                # Trigger Numba JIT compilation for N-gram proposer.
                # This usually takes less than 1 second.
                self.drafter.propose(
                    np.zeros(1024, dtype=np.int32),
                    self.speculative_config.prompt_lookup_min,
                    self.speculative_config.prompt_lookup_max,
                    self.speculative_config.num_speculative_tokens,
                )
                self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        # Only relevant for models using ALiBi (e.g., MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Sized to cover every arange slice taken later (per-request indices,
        # per-token indices, positions up to max_model_len).
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

274
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Sync the cached request states and the persistent batch with the
        scheduler output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The sampling metadata is refreshed (via
        ``input_batch.refresh_sampling_metadata()``) only when a request was
        removed from or added to the persistent batch during this call.

        Args:
            scheduler_output: The scheduler's decision for this step.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    # Drop the per-request dict once it is empty.
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        # The set difference materializes a new set, so removing entries
        # from `req_id_to_index` inside the loop below is safe.
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            if self.uses_mrope:
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in req_state.mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])

                hf_config = self.model_config.hf_config

                req_state.mrope_positions, req_state.mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        req_state.prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                    )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first: sorting in descending
        # order makes `pop()` return the smallest remaining index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()

463
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata]]:
        """Build the model inputs for one step from the persistent batch.

        Fills the CPU staging buffers (input ids, positions, slot mapping,
        query start locations, sequence lengths), copies the input ids and
        positions into their persistent GPU buffers, and asks the attention
        metadata builder for this step's attention metadata.

        Args:
            scheduler_output: The scheduler's decision for this step. Must
                schedule at least one token.

        Returns:
            A tuple ``(attn_metadata, logits_indices, spec_decode_metadata)``.
            ``logits_indices`` is the last scheduled token of every request
            when no draft tokens are scheduled; otherwise it comes from
            ``spec_decode_metadata.logits_indices``. ``spec_decode_metadata``
            is ``None`` when no draft tokens are scheduled.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        modified_batch = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)
        if modified_batch:
            self.input_batch.refresh_sampling_metadata()

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request, in the
        # (possibly reordered) persistent-batch order.
        tokens = [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in self.input_batch.req_ids
        ]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        # `tokens` is non-empty because num_reqs > 0 (asserted above).
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        if self.cascade_attn_enabled:
            common_prefix_len = self._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.num_common_prefix_blocks,
            )

        attn_metadata = self.attn_metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=common_prefix_len,
        )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = attn_metadata.query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices, spec_decode_metadata

623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens (always a Python ``int``,
            never a NumPy scalar).
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        # `.min()` above returns a NumPy scalar, so convert back to a Python
        # int to honor the `-> int` contract for the backend and callers.
        common_prefix_len = int(common_prefix_len // self.block_size *
                                self.block_size)
        use_cascade = self.attn_backend.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=self.window_size is not None,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

707
708
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        """Fill ``self.mrope_positions_cpu`` for this step's scheduled tokens.

        Requests are written back to back in batch order. Scheduled tokens
        that still lie inside a request's prompt copy the pre-computed
        ``req.mrope_positions``; tokens past the prompt get positions built
        on the fly by ``MRotaryEmbedding.get_next_input_positions_tensor``.
        """
        write_pos = 0
        for batch_idx, req_id in enumerate(self.input_batch.req_ids):
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            computed = self.input_batch.num_computed_tokens_cpu[batch_idx]
            scheduled = scheduler_output.num_scheduled_tokens[req_id]
            prompt_len = len(req.prompt_token_ids)

            # Split this step's tokens into the part that still falls inside
            # the prompt and the part that lies beyond it.
            n_prompt = min(scheduled, max(0, prompt_len - computed))
            n_completion = scheduled - n_prompt

            if n_prompt > 0:
                # Prompt positions are pre-computed; copy the matching slice.
                dst = slice(write_pos, write_pos + n_prompt)
                src = slice(computed, computed + n_prompt)
                self.mrope_positions_cpu[:, dst] = req.mrope_positions[:, src]
                write_pos += n_prompt

            if n_completion > 0:
                # Positions beyond the prompt are generated on the fly.
                context_len = computed + n_prompt
                dst = slice(write_pos, write_pos + n_completion)
                self.mrope_positions_cpu[:, dst] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=context_len,
                        seq_len=context_len + n_completion,
                    )
                write_pos += n_completion

759
760
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors used to verify speculative (draft) tokens.

        Each request owns ``num_draft_tokens[i] + 1`` consecutive "sampled"
        rows: one "target" row per draft token followed by one trailing
        "bonus" row. Those rows come from the last ``num_draft_tokens[i] + 1``
        scheduled tokens of the request.

        Example::

            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
            ->
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
                                       206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]

        ``logits_indices`` indexes the flat token batch, while
        ``target_logits_indices`` / ``bonus_logits_indices`` index the rows
        selected by ``logits_indices``.

        Args:
            num_draft_tokens: Number of draft tokens per request.
            cu_num_scheduled_tokens: Inclusive prefix sum of the number of
                scheduled tokens per request.

        Returns:
            A ``SpecDecodeMetadata`` whose index tensors live on
            ``self.device``.
        """
        # Sampled rows per request (drafts + bonus): [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
        # Segment ends / starts over the sampled rows.
        sampled_ends = np.cumsum(num_sampled_tokens,
                                 dtype=np.int32)  # [4, 5, 8, 9, 11]
        sampled_starts = sampled_ends - num_sampled_tokens  # [0, 4, 5, 8, 9]
        total_num_sampled_tokens = sampled_ends[-1]

        # Rank of every sampled row inside its own request's segment:
        # [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        sampled_rank = (self.arange_np[:total_num_sampled_tokens] -
                        np.repeat(sampled_starts, num_sampled_tokens))

        # Flat token index of every sampled row:
        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] + rank
        # = [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices = np.repeat(cu_num_scheduled_tokens - num_sampled_tokens,
                                   num_sampled_tokens)
        logits_indices += sampled_rank

        # The bonus row is the last row of each segment: [3, 4, 7, 8, 10]
        bonus_logits_indices = sampled_ends - 1

        # Target rows are the first `num_draft_tokens` rows of each segment.
        cu_num_draft_tokens = np.cumsum(num_draft_tokens,
                                        dtype=np.int32)  # [3, 3, 5, 5, 6]
        total_num_draft_tokens = cu_num_draft_tokens[-1]
        # [0, 1, 2, 0, 1, 0]
        draft_rank = (self.arange_np[:total_num_draft_tokens] -
                      np.repeat(cu_num_draft_tokens - num_draft_tokens,
                                num_draft_tokens))
        # [0, 0, 0, 5, 5, 9] + rank = [0, 1, 2, 5, 6, 9]
        target_logits_indices = np.repeat(sampled_starts, num_draft_tokens)
        target_logits_indices += draft_rank

        # TODO: Optimize the CPU -> GPU copy.
        def to_device(array: np.ndarray) -> torch.Tensor:
            return torch.from_numpy(array).to(self.device, non_blocking=True)

        cu_num_draft_tokens_dev = to_device(cu_num_draft_tokens)
        logits_indices = to_device(logits_indices)
        target_logits_indices = to_device(target_logits_indices)
        bonus_logits_indices = to_device(bonus_logits_indices)

        # Each draft token sits one sampled row after its target row; in the
        # example these are flat positions [1, 2, 3, 105, 106, 208].
        draft_token_ids = self.input_ids[logits_indices][
            target_logits_indices + 1]

        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens_dev,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

834
835
836
837
838
839
    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        Each output item is stored in ``self.encoder_cache[req_id][input_id]``.
        Nothing is returned. Returns immediately when no encoder inputs are
        scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Flatten the scheduled inputs while remembering which
        # (request, input index) each one came from, so outputs can be routed
        # back to the cache afterwards.
        mm_inputs: list[MultiModalKwargs] = []
        req_input_ids: list[tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            mm_inputs.extend(req_state.mm_inputs[input_id]
                             for input_id in encoder_input_ids)
            req_input_ids.extend((req_id, input_id)
                                 for input_id in encoder_input_ids)

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        encoder_outputs = []
        for grouped_mm_inputs in group_mm_inputs_by_modality(mm_inputs):
            batched_mm_inputs = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(grouped_mm_inputs),
                device=self.device,
            )

            # Run the encoder. The result is either
            # 1. a tensor of shape (num_items, feature_size, hidden_size) when
            #    feature_size is fixed across all multimodal items, or
            # 2. a list/tuple of num_items tensors, each of shape
            #    (feature_size, hidden_size), when the feature size depends on
            #    the individual item.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

            # Iterating either form yields one entry per multimodal item.
            encoder_outputs.extend(curr_group_outputs)

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> list[torch.Tensor]:
        """Collect the cached encoder-output slices needed in this step.

        For every request (in batch order) the current token window is
        ``[num_computed_tokens, num_computed_tokens + num_scheduled_tokens)``.
        For each multimodal item overlapping that window, the overlapping
        slice of its cached encoder output is appended to the result. The
        scan over ``mm_positions`` stops at the first item starting at or
        after the window end, so it assumes items are ordered by offset.
        """
        encoder_outputs: list[torch.Tensor] = []
        for req_id in self.input_batch.req_ids:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            window_start = req_state.num_computed_tokens
            window_end = window_start + num_scheduled_tokens

            for i, pos_info in enumerate(req_state.mm_positions):
                item_start = pos_info["offset"]
                item_len = pos_info["length"]

                if item_start >= window_end:
                    # Not needed in this step (nor, given ordered offsets,
                    # is any later item).
                    break
                if item_start + item_len <= window_start:
                    # Already processed and stored in the decoder's KV cache.
                    continue

                start_idx = max(window_start - item_start, 0)
                end_idx = min(window_end - item_start, item_len)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                cached_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(cached_output[start_idx:end_idx])
        return encoder_outputs

925
926
927
    def get_model(self) -> nn.Module:
        """Return the model object stored on ``self.model``."""
        model = self.model
        return model

928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output token bitmask to ``logits`` in place.

        The scheduler ships the bitmask as an ``np.ndarray`` (much cheaper to
        serialize than a tensor). Its row for a request is that request's
        entry in ``structured_output_request_ids``. The scheduler does not
        know this runner's batch ordering, so when those rows differ from the
        batch indices the rows are moved to the batch indices first. Returns
        without doing anything when no bitmask was sent.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        mask_rows = scheduler_output.structured_output_request_ids
        # req_id -> (row in the scheduler's bitmask, row in this batch) for
        # every structured-output request, in batch iteration order.
        row_pairs: dict[str, tuple[int, int]] = {}
        for req_id in self.input_batch.req_ids:
            mask_row = mask_rows.get(req_id)
            if mask_row is None:
                # not a structured output request
                continue
            row_pairs[req_id] = (mask_row,
                                 self.input_batch.req_id_to_index[req_id])

        if any(mask_row != batch_row
               for mask_row, batch_row in row_pairs.values()):
            # Move each row to its batch position; all other rows are zero.
            reordered = np.zeros_like(grammar_bitmask)
            for mask_row, batch_row in row_pairs.values():
                reordered[batch_row] = grammar_bitmask[mask_row]
            grammar_bitmask = reordered

        # TODO: compatibility with spec decode
        xgr.apply_token_bitmask_inplace(
            logits,
            torch.from_numpy(grammar_bitmask).to(self.device,
                                                 non_blocking=True),
            indices=[batch_row for _, batch_row in row_pairs.values()],
        )

975
976
977
978
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one scheduled step through the model.

        The step updates the persistent batch state and runs the multimodal
        encoder if present. It then builds the decoder inputs and runs the
        forward pass. On the last pipeline-parallel rank it also:
        - computes logits and applies structured-output bitmasks;
        - samples, using rejection sampling when draft tokens were scheduled;
        - optionally proposes draft tokens for the next step.

        Args:
            scheduler_output: This step's scheduling decisions.
            intermediate_tensors: Tensors received from the previous pipeline
                stage; only read when this is not the first PP rank.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` when no tokens were scheduled, the
            hidden states on a non-last PP rank, otherwise a
            ``ModelRunnerOutput``.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Nothing to do: return an empty ModelRunnerOutput.
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Multimodal models: run the encoder, then collect the slices of its
        # outputs that overlap this step's tokens.
        encoder_outputs = []
        if self.is_multimodal_model:
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)

        # Prepare the decoder inputs.
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))

        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        num_input_tokens = num_scheduled_tokens  # Eager mode by default.
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Piecewise CUDA graphs: pad the batch up to a captured size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        attn_metadata.num_input_tokens = num_input_tokens

        if not self.is_multimodal_model:
            # Text-only models take token ids. Feeding embeddings instead
            # would leave the embedding layer outside the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        else:
            # NOTE(woosuk): Multimodal models always take embeddings (even
            # for plain text) so that token ids and soft tokens (vision
            # embeddings) share one input path.
            text_ids = self.input_ids[:num_scheduled_tokens]
            if encoder_outputs:
                embeds = self.model.get_input_embeddings(
                    text_ids, encoder_outputs)
            else:
                embeds = self.model.get_input_embeddings(text_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None

        positions = (self.mrope_positions[:, :num_input_tokens]
                     if self.uses_mrope else self.positions[:num_input_tokens])

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            # Stage the received tensors in the persistent buffers.
            for name, received in intermediate_tensors.items():
                self.intermediate_tensors[name][:num_input_tokens].copy_(
                    received[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                name: buf[:num_input_tokens]
                for name, buf in self.intermediate_tensors.items()
            })

        # Run the decoder on the persistent buffers (needed for CUDA graphs).
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # Mid-pipeline stages just return their hidden states.
            return hidden_states

        # Drop the CUDA-graph padding before computing logits.
        hidden_states = hidden_states[:num_scheduled_tokens]
        logits = self.model.compute_logits(hidden_states[logits_indices], None)

        # Apply structured output bitmasks if present.
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # Indexing `logits` with an index tensor produces a copy with its
            # own storage, so in-place updates to `bonus_logits` or
            # `target_logits` never touch `logits` itself.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.model.sample(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            sampler_output.sampled_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )

        # TODO(woosuk): This per-request loop can be slow. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                continue
            # Partial prefill: the sampled token will be ignored, so rewind
            # the generator state as if it had not been sampled.
            # This relies on cuda-specific torch-internal impl details.
            generator = self.input_batch.generators.get(i)
            if generator is not None:
                generator.set_offset(generator.get_offset() - 4)
            # Remember the request so its sampled tokens are cleared below.
            discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = (None if logprobs_tensors is None else
                          logprobs_tensors.tolists())

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        if sampled_token_ids.shape[-1] == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )

        # Clear the tokens sampled for partial prefills.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        spec_token_ids = (self.generate_draft_token_ids(
            valid_sampled_token_ids, sampling_metadata)
                          if self.use_spec_decode else None)

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

1161
1162
    def generate_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Propose draft tokens for the next step, one list per request.

        A request gets ``[]`` in three cases: it sampled nothing this step,
        ``is_spec_decode_supported`` rejects it, or the drafter returns no
        proposal. Otherwise its new tokens are first written into
        ``input_batch.token_ids_cpu`` right after ``num_tokens_no_spec``, and
        the drafter is run on that updated prefix.
        ``sampling_metadata`` is not read here.
        """
        # TODO(woosuk): Optimize.
        proposals: list[list[int]] = []
        for i, new_ids in enumerate(sampled_token_ids):
            # Skip speculative decoding when nothing was sampled, or when the
            # request needs sampling options spec decode does not support
            # (top-p, top-k, etc.).
            if len(new_ids) == 0 or not is_spec_decode_supported(
                    self.input_batch.req_ids[i], self.input_batch):
                proposals.append([])
                continue

            # Append the newly sampled tokens to token_ids_cpu.
            begin = self.input_batch.num_tokens_no_spec[i]
            end = begin + len(new_ids)
            self.input_batch.token_ids_cpu[i, begin:end] = new_ids

            proposal = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end],
                self.speculative_config.prompt_lookup_min,
                self.speculative_config.prompt_lookup_max,
                self.speculative_config.num_speculative_tokens,
            )
            has_proposal = proposal is not None and len(proposal) > 0
            proposals.append(proposal.tolist() if has_proposal else [])
        return proposals

1197
1198
1199
    def load_model(self) -> None:
        """Load the model (wrapping it for LoRA if configured) into ``self.model``.

        The device memory consumed while loading is stored in
        ``self.model_memory_usage``. Memory and wall-clock time are logged.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            load_started = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                # Replace the base model with the LoRA-enabled wrapper
                # returned by load_lora_model.
                self.model = self.load_lora_model(
                    self.model,
                    self.model_config,
                    self.scheduler_config,
                    self.lora_config,
                    self.device,
                )
            load_finished = time.perf_counter()
        self.model_memory_usage = m.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    load_finished - load_started)
1213

1214
1215
1216
1217
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for requests that requested them.

        Prompt logprobs can be produced over several steps when a prompt is
        chunk-prefilled. Partial results are written slice by slice into CPU
        ``LogprobsTensors`` kept in
        ``self.input_batch.in_progress_prompt_logprobs_cpu``. They are only
        returned once the last chunk of the prompt has been processed.

        Args:
            hidden_states: Hidden states for all tokens scheduled in this
                step. Rows for the request at batch index ``i`` start at
                ``self.query_start_loc_np[i]``.
            scheduler_output: Scheduler output for this step. Used for the
                number of scheduled tokens per request.

        Returns:
            Mapping from request id to its completed prompt
            ``LogprobsTensors``. Only requests whose prompt logprobs were
            completed in this step are included. Completed requests are also
            removed from ``num_prompt_logprobs`` and from the in-progress
            dict.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve. The hidden state at
            # prompt position i predicts the prompt token at position i + 1.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits for this request's scheduled prompt tokens. With
            # chunked prefill, only the tokens scheduled in this step are
            # covered, and each of them yields one prompt logprob.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only the slice needed for this chunk
            # is moved to the device, not the whole prompt on every step.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok +
                                         num_logits]).to(self.device,
                                                         non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async into this chunk's rows.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers before the
        # returned CPU tensors are read.
        if prompt_logprobs_dict:
            torch.cuda.synchronize()

        return prompt_logprobs_dict

1309
1310
1311
1312
1313
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
    ) -> torch.Tensor:
        """Run the model forward once on ``num_tokens`` placeholder tokens.

        The tokens are split across up to ``max_num_seqs`` fake requests, so
        that a LoRA dummy run sees the requests collectively covering
        ``num_tokens`` tokens. The forward pass runs inside
        ``maybe_dummy_run_with_lora`` and ``set_forward_context``.

        Args:
            num_tokens: Total number of tokens to run. Must not exceed
                ``max_num_batched_tokens``.

        Returns:
            The model output indexed at the last-token position of each fake
            request.
        """
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens

        # Spread the tokens evenly over the fake requests; any remainder
        # goes to the final request.
        num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)
        base_tokens, leftover = divmod(num_tokens, num_reqs)
        tokens_per_req = [base_tokens] * num_reqs
        tokens_per_req[-1] += leftover
        assert sum(tokens_per_req) == num_tokens
        assert len(tokens_per_req) == num_reqs
        num_scheduled_tokens = np.array(tokens_per_req, dtype=np.int32)

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            # Multimodal models are fed embeddings; others are fed token ids.
            if self.is_multimodal_model:
                input_ids, inputs_embeds = None, self.inputs_embeds[:num_tokens]
            else:
                input_ids, inputs_embeds = self.input_ids[:num_tokens], None

            positions = (self.mrope_positions[:, :num_tokens]
                         if self.uses_mrope else self.positions[:num_tokens])

            # Non-first pipeline stages take intermediate tensors as input.
            # The buffers are allocated lazily at max size, then sliced.
            intermediate_tensors = None
            if not get_pp_group().is_first_rank:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = IntermediateTensors({
                    name: buffer[:num_tokens]
                    for name, buffer in self.intermediate_tensors.items()
                })

            with set_forward_context(None,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

        # Index of the last token of every fake request.
        last_token_rows = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states[last_token_rows]

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler on one dummy request per row of hidden states.

        Logits are computed from ``hidden_states`` and sampled with
        placeholder sampling metadata. When speculative decoding is enabled,
        the rejection sampler is also run with one dummy draft token per
        request.

        Args:
            hidden_states: One hidden-state row per dummy request.

        Returns:
            The output of ``self.model.sample`` for the dummy batch.

        Raises:
            RuntimeError: If sampling runs out of memory. The re-raised
                error suggests lowering ``max_num_seqs`` or
                ``gpu_memory_utilization``. Other ``RuntimeError``s
                propagate unchanged.
        """
        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def dummy_tensors(v):
            # A (num_reqs,) tensor on the runner's device filled with `v`.
            return torch.full((num_reqs, ), v, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            min_p=None,
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
        )
        try:
            sampler_output = self.model.sample(
                logits=logits, sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up sampler with "
                    f"{num_reqs} dummy requests. Please try lowering "
                    "`max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine.") from e
            # Bare `raise` re-raises with the original traceback intact.
            raise

        if self.use_spec_decode:
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)

            num_tokens = sum(len(ids) for ids in draft_token_ids)
            # Draft probabilities are not supplied for this warm-up; the
            # rejection sampler is called with `None`.
            draft_probs = None
            target_logits = torch.randn(num_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output
1442
1443

    def profile_run(self) -> None:
        """Run a maximum-size dummy forward pass to profile memory usage.

        For multimodal models with a non-zero encoder budget, the multimodal
        encoder is profiled first. Its outputs are kept in the encoder cache
        during the dummy forward pass. On the last pipeline-parallel rank the
        sampler is also run. All temporary results are released afterwards.
        """
        # TODO: handle encoder-decoder models once we support them.
        should_profile_encoder = (self.is_multimodal_model
                                  and self.max_num_encoder_input_tokens > 0
                                  and self.encoder_cache_size > 0)
        if should_profile_encoder:
            self._profile_multimodal_encoder()

        hidden_states = self._dummy_run(self.max_num_tokens)
        sampler_output = (self._dummy_sampler_run(hidden_states)
                          if get_pp_group().is_last_rank else None)
        torch.cuda.synchronize()
        del hidden_states, sampler_output
        self.encoder_cache.clear()
        gc.collect()

    def _profile_multimodal_encoder(self) -> None:
        """Run the multimodal encoder on a worst-case batch of dummy items.

        Only one non-text modality is profiled: the one with the largest
        maximum tokens per item, even when the model supports several. The
        encoder outputs are stored under ``self.encoder_cache["tmp"]``.
        """
        registry = self.mm_registry
        max_tokens_by_modality = (
            registry.get_max_tokens_per_item_by_nonzero_modality(
                self.model_config))
        modality, max_tokens_per_item = max(max_tokens_by_modality.items(),
                                            key=lambda entry: entry[1])

        # Encoder-side limit on the number of items.
        encoder_budget = min(self.max_num_encoder_input_tokens,
                             self.encoder_cache_size)
        items_by_encoder_budget = cdiv(encoder_budget, max_tokens_per_item)

        # Decoder-side limit on the number of items. max_num_batched_tokens
        # is deliberately ignored: multimodal embeddings can be generated in
        # advance and chunk-prefilled.
        items_per_req = registry.get_mm_limits_per_prompt(
            self.model_config)[modality]
        items_by_decoder_budget = self.max_num_reqs * items_per_req

        num_items = min(items_by_encoder_budget, items_by_decoder_budget)

        logger.info(
            "Encoder cache will be initialized with a budget of %s tokens,"
            " and profiled with %s %s items of the maximum feature size.",
            encoder_budget, num_items, modality)

        # Build `num_items` copies of one dummy item and move them to device.
        single_item = registry.get_decoder_dummy_data(
            model_config=self.model_config,
            seq_len=self.max_num_tokens,
            mm_counts={modality: 1},
        ).multi_modal_data
        device_inputs = MultiModalKwargs.as_kwargs(
            MultiModalKwargs.batch([single_item] * num_items),
            device=self.device)

        encoder_outputs = self.model.get_multimodal_embeddings(
            **device_inputs)
        sanity_check_mm_encoder_outputs(
            encoder_outputs,
            expected_num_items=num_items,
        )

        # Keep the outputs referenced from the encoder cache.
        self.encoder_cache["tmp"] = dict(enumerate(encoder_outputs))
1519
1520

    def capture_model(self) -> None:
        """Capture CUDA graphs for each size in ``self.cudagraph_batch_sizes``.

        If CUDA graphs are disabled, a warning is logged and nothing is
        captured. Otherwise, for each size (largest first), the configured
        number of warm-up dummy runs is executed, followed by one more
        dummy run, all inside ``graph_capture``. The elapsed time and the
        drop in free GPU memory are logged.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        compilation_config = self.vllm_config.compilation_config
        num_warmups = compilation_config.cudagraph_num_of_warmups

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)
1547

1548
1549
1550
1551
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Allocate zero-initialized KV cache tensors described by
        `kv_cache_config` and bind them via `bind_kv_cache`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the
                tensor size of each layer and the block count the KV cache
                manager may allocate.

        Raises:
            NotImplementedError: If there is more than one KV cache group.
            ValueError: If a group's KV cache spec is not an AttentionSpec.
        """
        groups = kv_cache_config.kv_cache_groups
        if len(groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}
        for group in groups:
            spec = group.kv_cache_spec
            page_bytes = spec.page_size_bytes
            for layer_name in group.layer_names:
                layer_bytes = kv_cache_config.tensors[layer_name].size
                assert layer_bytes % page_bytes == 0
                # Blocks this runner can hold for the layer. The config's
                # `num_blocks` is the minimum across GPUs (layer counts and
                # memory may differ), so it must never exceed the local count.
                num_blocks = layer_bytes // page_bytes
                assert num_blocks >= kv_cache_config.num_blocks

                if not isinstance(spec, AttentionSpec):
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

                shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)
                kv_caches[layer_name] = torch.zeros(shape,
                                                    dtype=spec.dtype,
                                                    device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

1594
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Build the KV cache format of every attention layer registered in the
        static forward context.

        Returns:
            dict[str, KVCacheSpec]: Maps layer name to its KV cache spec.
            FusedMoE modules and encoder / encoder-only attention layers need
            no KV cache and are left out.

        Raises:
            NotImplementedError: For encoder-decoder attention layers.
            ValueError: For an unrecognized attention type.
        """
        compilation_config = self.vllm_config.compilation_config
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla

        specs: dict[str, KVCacheSpec] = {}
        for name, module in compilation_config.static_forward_context.items():
            # FusedMoE modules live in the forward context but have no KV
            # cache.
            if isinstance(module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., cross-attention.
            assert isinstance(module, Attention)
            attn_type = module.attn_type
            if attn_type in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY):
                # Encoder attention does not need a KV cache.
                continue
            if attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            if attn_type != AttentionType.DECODER:
                raise ValueError(
                    f"Unknown attention type: {module.attn_type}")

            # Decoder attention: sliding-window or full attention spec.
            shared_fields = dict(block_size=block_size,
                                 num_kv_heads=module.num_kv_heads,
                                 head_size=module.head_size,
                                 dtype=self.kv_cache_dtype,
                                 use_mla=use_mla)
            if module.sliding_window is None:
                specs[name] = FullAttentionSpec(**shared_fields)
            else:
                specs[name] = SlidingWindowSpec(
                    sliding_window=module.sliding_window, **shared_fields)

        return specs