gpu_model_runner.py 79.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
import weakref
6
from typing import TYPE_CHECKING, Optional, Union
7
8
9
10
11
12

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

13
from vllm.attention import AttentionType, get_attn_backend
14
from vllm.attention.layer import Attention
15
from vllm.config import CompilationLevel, VllmConfig
16
17
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
18
from vllm.distributed.parallel_state import get_pp_group, graph_capture
19
20
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
21
from vllm.model_executor.layers.fused_moe import FusedMoE
22
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
23
from vllm.model_executor.model_loader import get_model
24
25
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
26
from vllm.multimodal.utils import group_mm_inputs_by_modality
27
from vllm.sampling_params import SamplingType
28
from vllm.sequence import IntermediateTensors
29
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
30
31
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
32
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
33
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
34
35
36
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
37
38
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
39
from vllm.v1.sample.metadata import SamplingMetadata
40
from vllm.v1.sample.rejection_sampler import RejectionSampler
41
from vllm.v1.spec_decode.eagle import EagleProposer
42
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
43
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
44
from vllm.v1.spec_decode.utils import is_spec_decode_supported
45
from vllm.v1.utils import bind_kv_cache
46
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
47
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
48

49
50
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
51

52
if TYPE_CHECKING:
    # These imports are only evaluated by static type checkers; at runtime
    # the `else` branch below is taken instead.
    import xgrammar as xgr

    from vllm.v1.core.sched.output import SchedulerOutput
else:
    # NOTE(review): presumably LazyLoader defers the real `xgrammar` import
    # until `xgr` is first used - confirm against vllm.utils.LazyLoader.
    xgr = LazyLoader("xgr", globals(), "xgrammar")

# Module-level logger, named after this module.
logger = init_logger(__name__)


62
class GPUModelRunner(LoRAModelRunnerMixin):
63
64
65

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Set up configuration shortcuts, the attention backend, and the
        persistent CPU/GPU buffers used by every model step.

        The model itself and the KV caches are not created here;
        `self.kv_caches` starts empty.

        Args:
            vllm_config: Top-level configuration; its sub-configs (model,
                cache, scheduler, parallel, speculative, ...) are unpacked
                onto `self`.
            device: Device on which the persistent GPU buffers are allocated.

        Raises:
            NotImplementedError: If `get_attn_backend` returns None.
            ValueError: If speculative decoding is configured with a method
                other than "ngram" or "eagle" (only checked on the last
                pipeline-parallel rank).
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Imported locally, as in the original code.
        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes

        # cpu_offload_gb is expressed in GiB; convert to bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the same dtype as the model.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            # Log the exact arguments so the failing combination can be
            # identified before raising.
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        # The builder receives a weak proxy rather than a strong reference to
        # this runner (presumably to avoid a runner <-> builder cycle).
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))
        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            # The drafter and rejection sampler are only created on the last
            # pipeline-parallel rank.
            if get_pp_group().is_last_rank:
                if self.speculative_config.method == "ngram":
                    self.drafter = NgramProposer(self.vllm_config)
                elif self.speculative_config.method == "eagle":
                    self.drafter = EagleProposer(self.vllm_config,
                                                 self.device)  # type: ignore
                else:
                    raise ValueError("Unknown speculative decoding method: "
                                     f"{self.speculative_config.method}")
                self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # CUDA graphs are used only with piecewise compilation and when eager
        # execution is not enforced.
        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # The arange is sized to cover every slice length taken from it.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each `*_np` attribute is a NumPy view sharing memory with the
        # matching `*_cpu` tensor.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()
276

277
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Args:
            scheduler_output: The scheduler's decision for this step. This
                method reads its `finished_req_ids`, `free_encoder_input_ids`,
                `num_scheduled_tokens`, `scheduled_new_reqs`,
                `scheduled_cached_reqs` and `scheduled_spec_decode_tokens`
                fields.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                # Drop the per-request dict once its last entry is gone.
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            # Seeded sampling gets its own generator, created on the runner's
            # device and seeded from the request.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                # Collect the per-modality metadata across all multi-modal
                # inputs of this request.
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                audio_feature_lengths = []
                use_audio_in_video = False
                for mm_input in req_state.mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])
                    if mm_input.get("audio_feature_lengths") is not None:
                        audio_feature_lengths.extend(
                            mm_input["audio_feature_lengths"])
                    if mm_input.get("use_audio_in_video") is True:
                        use_audio_in_video = True

                hf_config = self.model_config.hf_config

                req_state.mrope_positions, req_state.mrope_position_delta = (
                    MRotaryEmbedding.get_input_positions_tensor(
                        req_state.prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                        audio_feature_lengths=audio_feature_lengths,
                        use_audio_in_video=use_audio_in_video,
                    ))

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        # Sorting in descending order lets `pop()` return the smallest index.
        removed_req_indices.sort(reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        batch_reordered = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)

        if batch_changed or batch_reordered:
            self.input_batch.refresh_sampling_metadata()
480

481
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata]]:
        """Fill the persistent input buffers for this step and build the
        attention metadata.

        Computes per-token positions, gathers the input token IDs, derives the
        KV-cache slot mapping from the block table, and issues non-blocking
        copies of input IDs and positions into the persistent GPU buffers.

        Args:
            scheduler_output: The scheduler's decision for this step. Must
                schedule at least one token, and the persistent batch must
                hold at least one request (both are asserted).

        Returns:
            A tuple `(attn_metadata, logits_indices, spec_decode_metadata)`.
            `spec_decode_metadata` is None when no request in this step has
            scheduled speculative tokens.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in req_ids
        ]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions.
        # `positions_np` is a view into `self.positions_np`, so writing into
        # it fills the persistent CPU buffer in place.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        # slot = block_number * block_size + offset within the block.
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        if self.cascade_attn_enabled:
            common_prefix_len = self._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.num_common_prefix_blocks,
            )

        attn_metadata = self.attn_metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=common_prefix_len,
        )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            # Index of the last scheduled token of each request.
            logits_indices = attn_metadata.query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices, spec_decode_metadata
625

626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
686
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
687
688
689
690
691
692
693
694
695
696
697
698
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
699
        use_cascade = self.attn_metadata_builder.use_cascade_attention(
700
701
702
703
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
704
            use_alibi=self.use_alibi,
705
            use_sliding_window=self.window_size is not None,
706
707
708
709
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

710
711
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
712
        for index, req_id in enumerate(self.input_batch.req_ids):
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

762
763
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors used for speculative-decoding sampling.

        All index arithmetic is done with NumPy on the CPU; the resulting
        arrays are then copied to `self.device` with non-blocking transfers.

        Args:
            num_draft_tokens: Number of draft tokens per request.
            cu_num_scheduled_tokens: Cumulative number of scheduled tokens
                per request.

        Returns:
            A `SpecDecodeMetadata` holding the draft token ids together with
            the logits, target-logits and bonus-logits indices.
        """
        # Inputs:
        # cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
        # num_draft_tokens:         [  3,   0,   2,   0,   1]
        # Outputs:
        # cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
        # logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
        #                            206, 207, 208]
        # target_logits_indices:    [  0,   1,   2,   5,   6,   9]
        # bonus_logits_indices:     [  3,   4,   7,   8,  10]

        # Compute the logits indices.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
        # Step 1. [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        total_num_sampled_tokens = cu_num_sampled_tokens[-1]
        # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        cumsums_offsets = np.repeat(cu_num_sampled_tokens - num_sampled_tokens,
                                    num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        arange = self.arange_np[:total_num_sampled_tokens] - cumsums_offsets
        # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices += arange

        # Compute the bonus logits indices.
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # Compute the draft logits indices.
        # [3, 3, 5, 5, 6]
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        total_num_draft_tokens = cu_num_draft_tokens[-1]
        # [0, 0, 0, 3, 3, 5]
        cumsums_offsets = np.repeat(cu_num_draft_tokens - num_draft_tokens,
                                    num_draft_tokens)
        # [0, 1, 2, 0, 1, 0]
        arange = self.arange_np[:total_num_draft_tokens] - cumsums_offsets
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices += arange

        # TODO: Optimize the CPU -> GPU copy.
        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
            self.device, non_blocking=True)
        logits_indices = torch.from_numpy(logits_indices).to(self.device,
                                                             non_blocking=True)
        target_logits_indices = torch.from_numpy(target_logits_indices).to(
            self.device, non_blocking=True)
        bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to(
            self.device, non_blocking=True)

        # Compute the draft token ids.
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        draft_token_ids = self.input_ids[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

837
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
838
839
840
841
842
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
843
844
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
845
846
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
847
848
849
850
851

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

878
879
880
881
882
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

883
884
            for output in curr_group_outputs:
                encoder_outputs.append(output)
885
886

        # Cache the encoder outputs.
887
888
889
890
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
891
892
893
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}

894
895
896
897
898
899
            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
900
901
        self,
        scheduler_output: "SchedulerOutput",
902
    ) -> list[torch.Tensor]:
903
        mm_embeds: list[torch.Tensor] = []
904
        for req_id in self.input_batch.req_ids:
905
906
907
908
909
910
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
911
912
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
934
935
936
937
938
939
940
941
942
943

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds
944

945
946
947
    def get_model(self) -> nn.Module:
        """Return the model object stored on this runner as `self.model`."""
        return self.model

948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output token bitmask to `logits` in place.

        The scheduler supplies one bitmask row per structured-output request,
        indexed by `scheduler_output.structured_output_request_ids`. When
        those indices differ from the requests' positions in the batch, the
        rows are moved into a batch-aligned bitmask before it is applied.

        Args:
            scheduler_output: Carries `grammar_bitmask` (an `np.ndarray` or
                None) and `structured_output_request_ids`.
            logits: Logits tensor, modified in place; its first dimension is
                indexed by batch index.

        Returns:
            None. Does nothing when no bitmask was sent.
        """
        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        # We receive the structured output bitmask from the scheduler, but the
        # indices of the requests in the batch may not match the indices of
        # the bitmask since the scheduler doesn't know how the gpu runner is
        # ordering the requests in the batch. We need to sort the bitmask to
        # match the order of the requests used here.
        mask_index_by_req_id = scheduler_output.structured_output_request_ids
        struct_out_req_batch_indices: dict[str, int] = {}
        indices_match = True
        for req_id in self.input_batch.req_ids:
            mask_index = mask_index_by_req_id.get(req_id)
            if mask_index is None:
                # not a structured output request
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            if batch_index != mask_index:
                indices_match = False
            struct_out_req_batch_indices[req_id] = batch_index

        if not indices_match:
            # Sort the bitmask to match the order of the requests.
            # The reordered bitmask is indexed by batch index, so it needs one
            # row per logits row; the incoming bitmask only has one row per
            # structured-output request and may be too short for that.
            sorted_bitmask = np.zeros(
                (logits.shape[0], *grammar_bitmask.shape[1:]),
                dtype=grammar_bitmask.dtype)
            for req_id, batch_index in struct_out_req_batch_indices.items():
                orig_index = mask_index_by_req_id[req_id]
                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
            grammar_bitmask = sorted_bitmask

        grammar_bitmask = torch.from_numpy(grammar_bitmask)

        # TODO: compatibility with spec decode
        xgr.apply_token_bitmask_inplace(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=list(struct_out_req_batch_indices.values()),
        )

995
996
997
998
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one scheduled step: forward pass, sampling and drafting.

        The steps are, in order: update the persistent batch state, prepare
        the decoder inputs, run the multimodal encoder if the model is
        multimodal, run the model forward pass, compute logits, apply the
        structured-output bitmask, sample (with rejection sampling when
        speculative tokens were scheduled), and propose draft tokens for the
        next step when speculative decoding is enabled.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Tensors received from the previous
                pipeline-parallel stage; must be provided on every rank
                except the first.

        Returns:
            `EMPTY_MODEL_RUNNER_OUTPUT` when no tokens are scheduled, the raw
            hidden states on pipeline stages other than the last, and
            otherwise a `ModelRunnerOutput` with the sampled tokens, draft
            tokens, logprobs and prompt logprobs.
        """
        # Update KVConnector with the KVConnector metadata forward().
        if has_kv_transfer_group():
            get_kv_transfer_group().bind_connector_metadata(
                scheduler_output.kv_connector_metadata)

        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Return empty ModelRunnerOutput if there's no work to do.
            return EMPTY_MODEL_RUNNER_OUTPUT

        # Prepare the decoder inputs.
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, mm_embeds)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            # Copy the received tensors into the persistent buffers, then
            # hand the model views of those buffers.
            for k, v in intermediate_tensors.items():
                self.intermediate_tensors[k][:num_input_tokens].copy_(
                    v[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                k: v[:num_input_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            return hidden_states

        # Drop the padded rows before selecting the rows to sample from.
        hidden_states = hidden_states[:num_scheduled_tokens]
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is None:
            sampler_output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.model.sample(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()

        if not self.use_spec_decode:
            # Speculative decoding is not enabled.
            spec_token_ids = None
        elif self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids, sampling_metadata)
        elif self.speculative_config.method == "eagle":
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(valid_sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]
                else:
                    # Partial prefill (rare case).
                    # Get the next token id from the request state.
                    req_id = self.input_batch.req_ids[i]
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)

            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                # We need to slice token_ids, positions, and hidden_states
                # because the eagle head does not use cuda graph and should
                # not include padding.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                target_positions = positions[:num_scheduled_tokens]
                target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = attn_metadata.slot_mapping
                cu_num_tokens = attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                num_rejected_tokens = [
                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                    for i, n in enumerate(num_draft_tokens)
                ]
                num_rejected_tokens = torch.tensor(
                    num_rejected_tokens,
                    dtype=torch.int32,
                    device=self.device,
                )
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    attn_metadata.query_start_loc,
                    num_rejected_tokens,
                )
                target_token_ids = self.input_ids[token_indices]
                target_positions = positions[token_indices]
                target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = attn_metadata.slot_mapping[token_indices]

            draft_token_ids, draft_probs = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=attn_metadata.block_table,
                sampling_metadata=sampling_metadata,
            )
            spec_token_ids = draft_token_ids.tolist()
            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
            # in the next step.
            del draft_probs

        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

1260
1261
    def generate_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Propose draft tokens for every request in the batch.

        For each request, the newly sampled tokens are written into
        `self.input_batch.token_ids_cpu` and the token history up to that
        point is handed to `self.drafter.propose`.

        Args:
            sampled_token_ids: Tokens sampled in this step, one list per
                request, indexed by the request's position in the batch.
            sampling_metadata: Not used by this method.

        Returns:
            One list of draft token ids per entry of `sampled_token_ids`,
            in the same order. An entry is empty when drafting was skipped
            for that request or the drafter produced nothing.
        """
        # TODO(woosuk): Optimize.
        draft_token_ids: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Nothing was sampled for this request: skip speculative
                # decoding.
                draft_token_ids.append([])
                continue

            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
            req_id = self.input_batch.req_ids[i]
            if not is_spec_decode_supported(req_id, self.input_batch):
                draft_token_ids.append([])
                continue

            # Add sampled_token_ids to token_ids_cpu.
            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + num_sampled_ids
            if end_idx >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
            drafter_output = self.drafter.propose(
                self.input_batch.token_ids_cpu[i, :end_idx])
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

1298
1299
1300
    def load_model(self) -> None:
        """Load the model (plus LoRA and drafter, if configured).

        Stores the loaded model in `self.model` and the memory consumed
        during loading, as reported by `DeviceMemoryProfiler`, in
        `self.model_memory_usage`. Logs the memory usage and the elapsed
        loading time.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            time_before_load = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)
            # The drafter attribute only exists when one was set up
            # elsewhere; it is loaded on top of the target model.
            if hasattr(self, "drafter"):
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            time_after_load = time.perf_counter()
        self.model_memory_usage = m.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    time_after_load - time_before_load)
1317

1318
1319
1320
1321
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for the requests that asked for them.

        Logprobs are accumulated chunk by chunk into per-request CPU tensors
        kept in `self.input_batch.in_progress_prompt_logprobs_cpu`. A
        request's tensors are returned, and its bookkeeping entries are
        removed, in the step where more tokens are scheduled for it than
        prompt tokens remain.

        Args:
            hidden_states: Hidden states produced by the model in this step.
            scheduler_output: Scheduler output for this step; only
                `num_scheduled_tokens` is read here.

        Returns:
            A dict mapping request id to the `LogprobsTensors` of every
            request whose prompt logprobs were completed in this step. Empty
            when no request in the batch needs prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for.
            # The prompt ids are moved to the device only here, so requests
            # skipped by the guard above trigger no transfer.
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)
            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        if prompt_logprobs_dict:
            torch.cuda.synchronize()

        return prompt_logprobs_dict

1413
1414
1415
1416
1417
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
    ) -> torch.Tensor:
        """Run one forward pass of the model over `num_tokens` dummy tokens.

        The tokens are split across up to `max_num_seqs` dummy requests so
        that the LoRA dummy-run context sees a per-request token count that
        sums to `num_tokens`.

        Args:
            num_tokens: Total number of tokens to run; must not exceed
                `scheduler_config.max_num_batched_tokens`.

        Returns:
            The model output indexed at the last token position of each
            dummy request.
        """
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req, leftover_tokens = divmod(num_tokens, num_reqs)
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        # The last request absorbs whatever does not divide evenly.
        num_scheduled_tokens_list[-1] += leftover_tokens
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
            # Multimodal models are fed embeddings instead of token ids.
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions[:num_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                # Allocate the intermediate-tensor buffers once, sized for
                # max_num_tokens, and slice them for this run.
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = IntermediateTensors({
                    k: v[:num_tokens]
                    for k, v in self.intermediate_tensors.items()
                })

            with set_forward_context(None,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

        # Index of the last token of each dummy request.
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states[logit_indices]

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the sampler once on dummy sampling metadata.

        Logits are computed from `hidden_states` and sampled with dummy
        parameters. When `self.use_spec_decode` is set, the rejection
        sampler is also invoked once with dummy draft tokens.

        Args:
            hidden_states: One row per dummy request.

        Returns:
            Whatever `self.model.sample` returns.
            NOTE(review): the signature is annotated `torch.Tensor`;
            confirm against the sampler's actual output type.

        Raises:
            RuntimeError: If sampling fails. An out-of-memory failure is
                re-raised with a message suggesting lower `max_num_seqs`
                or `gpu_memory_utilization`.
        """
        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def dummy_tensors(v):
            # One entry per dummy request, all holding the same value.
            return torch.full((num_reqs, ), v, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=dummy_tensors(0.5),
            all_greedy=False,
            all_random=False,
            top_p=dummy_tensors(0.9),
            top_k=dummy_tensors(logits.size(1) - 1),
            min_p=None,
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=dummy_tensors(0.1),
            presence_penalties=dummy_tensors(0.1),
            repetition_penalties=dummy_tensors(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
        )
        try:
            sampler_output = self.model.sample(
                logits=logits, sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                raise RuntimeError(
                    "CUDA out of memory occurred when warming up sampler with "
                    f"{num_reqs} dummy requests. Please try lowering "
                    "`max_num_seqs` or `gpu_memory_utilization` when "
                    "initializing the engine.") from e
            # Any other runtime error propagates with its original
            # traceback.
            raise

        if self.use_spec_decode:
            # One dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)

            num_tokens = sum(len(ids) for ids in draft_token_ids)
            # NOTE: No draft probabilities are supplied; the rejection
            # sampler is run with `draft_probs=None`.
            draft_probs = None
            target_logits = torch.randn(num_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output
1546
1547

    def profile_run(self) -> None:
        """Run the model on dummy inputs of maximum size.

        For multimodal models with a non-zero encoder budget, the
        multimodal encoder is first run on a dummy batch and its outputs
        are placed in `self.encoder_cache`. Then a dummy forward pass over
        `self.max_num_tokens` tokens is executed, followed by a dummy
        sampler run on the last pipeline-parallel rank. The temporary
        tensors and the encoder cache are released before returning.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            max_tokens_by_modality_dict = (
                self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                    self.model_config))
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = (self.max_num_reqs *
                                               max_mm_items_per_req)

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_counts={
                    dummy_data_modality: 1
                },
            ).multi_modal_data

            # The same single-item dummy input is repeated to fill the batch.
            batched_dummy_mm_inputs = MultiModalKwargs.batch(
                [dummy_mm_kwargs] * max_num_mm_items)
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)

            sanity_check_mm_encoder_outputs(
                dummy_encoder_outputs,
                expected_num_items=max_num_mm_items,
            )

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        hidden_states = self._dummy_run(self.max_num_tokens)
        # Only the last pipeline-parallel rank runs the sampler.
        if get_pp_group().is_last_rank:
            sampler_output = self._dummy_sampler_run(hidden_states)
        else:
            sampler_output = None
        torch.cuda.synchronize()
        # Drop the dummy outputs and cached encoder results.
        del hidden_states, sampler_output
        self.encoder_cache.clear()
        gc.collect()
1623
1624

    def capture_model(self) -> None:
        """Run dummy batches for every CUDA graph batch size.

        Does nothing except log a warning when `self.use_cuda_graph` is
        false. Otherwise each size in `self.cudagraph_batch_sizes` is run
        inside `graph_capture`, largest first, and the elapsed time and the
        drop in free GPU memory are logged.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            num_warmups = (
                self.vllm_config.compilation_config.cudagraph_num_of_warmups)
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                # Warm-up runs first, then one more run for this shape
                # (presumably the one that is actually captured).
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
1651

1652
1653
1654
1655
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        A zero-filled tensor is allocated on `self.device` for every layer
        named in the config, and the tensors are then bound via
        `bind_kv_cache`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer

        Raises:
            NotImplementedError: If the config has more than one KV cache
                group.
            ValueError: If a group's spec is not an `AttentionSpec`.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
                # The per-layer byte budget must hold a whole number of
                # pages.
                assert tensor_config.size % kv_cache_spec.page_size_bytes == 0
                num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes
                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
                # KVCacheManager may allocate.
                # Since different GPUs may have different number of layers and
                # different memory capacities, `num_blocks` can be different on
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                if isinstance(kv_cache_spec, AttentionSpec):
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                        dtype=dtype,
                                                        device=self.device)
                else:
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

1698
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: If a layer uses encoder-decoder attention.
            ValueError: If a layer has an unrecognized attention type.
        """

        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            # The forward context also holds FusedMoE layers; they are
            # skipped here.
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            # (decoder layers with and without a sliding window are handled
            # below).
            assert isinstance(attn_module, Attention)
            if attn_module.attn_type == AttentionType.DECODER:
                if attn_module.sliding_window is not None:
                    kv_cache_spec[layer_name] = SlidingWindowSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        sliding_window=attn_module.sliding_window,
                        use_mla=use_mla)
                else:
                    kv_cache_spec[layer_name] = FullAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
                        head_size=attn_module.head_size,
                        dtype=self.kv_cache_dtype,
                        use_mla=use_mla)
            elif attn_module.attn_type in (AttentionType.ENCODER,
                                           AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return kv_cache_spec