# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
import weakref
6
from typing import TYPE_CHECKING, Optional, Union
7
8
9
10
11
12

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

13
from vllm.attention import AttentionType, get_attn_backend
14
from vllm.attention.layer import Attention
15
from vllm.config import CompilationLevel, VllmConfig
16
from vllm.distributed.parallel_state import get_pp_group, graph_capture
17
18
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import FusedMoE
20
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
21
from vllm.model_executor.model_loader import get_model
22
23
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
24
from vllm.multimodal.utils import group_mm_inputs_by_modality
25
from vllm.sampling_params import SamplingType
26
from vllm.sequence import IntermediateTensors
27
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
28
29
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
30
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
31
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
32
33
34
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
35
36
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
37
from vllm.v1.sample.metadata import SamplingMetadata
38
from vllm.v1.sample.rejection_sampler import RejectionSampler
39
from vllm.v1.spec_decode.eagle import EagleProposer
40
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
41
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
42
from vllm.v1.spec_decode.utils import is_spec_decode_supported
43
from vllm.v1.utils import bind_kv_cache
44
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
45
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
46

47
48
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
49

50
if TYPE_CHECKING:
51
52
    import xgrammar as xgr

53
    from vllm.v1.core.sched.output import SchedulerOutput
54
55
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")
56
57
58
59

logger = init_logger(__name__)


60
class GPUModelRunner(LoRAModelRunnerMixin):
61
62
63

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, attention backend, and persistent buffers.

        No model weights are loaded here; `self.kv_caches` starts empty and
        `self.intermediate_tensors` starts as None.

        Args:
            vllm_config: Top-level config; its sub-configs (model, cache,
                scheduler, parallel, speculative, ...) are cached on `self`.
            device: Device on which the persistent GPU buffers are allocated
                and whose properties are queried via
                `torch.cuda.get_device_properties`.

        Raises:
            NotImplementedError: If `get_attn_backend` returns None.
            ValueError: If speculative decoding is configured with a method
                other than "ngram" or "eagle" (checked only on the last
                pipeline-parallel rank).
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Imported locally rather than at module scope.
        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        # cpu_offload_gb is converted from GiB to bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the same dtype as the model.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            # Log the exact arguments so the failing combination can be
            # reproduced, then fail loudly.
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        # The builder receives a weak proxy of the runner, so it does not
        # hold a strong reference back to `self`.
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))
        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            # The drafter and rejection sampler are only created on the last
            # pipeline-parallel rank; other ranks do not get these attributes.
            if get_pp_group().is_last_rank:
                if self.speculative_config.method == "ngram":
                    self.drafter = NgramProposer(self.vllm_config)
                elif self.speculative_config.method == "eagle":
                    self.drafter = EagleProposer(self.vllm_config,
                                                 self.device)  # type: ignore
                else:
                    raise ValueError("Unknown speculative decoding method: "
                                     f"{self.speculative_config.method}")
                self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Sized to the largest slice taken from it anywhere in the runner.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each `*_np` array below is a NumPy view obtained from the matching
        # CPU tensor via `.numpy()`.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()
274

275
    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Args:
            scheduler_output: The scheduler's decision for this step. This
                method reads `finished_req_ids`, `free_encoder_input_ids`,
                `num_scheduled_tokens`, `scheduled_new_reqs`,
                `scheduled_cached_reqs` and `scheduled_spec_decode_tokens`.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                # Drop the per-request dict once its last entry is gone.
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            # Only seeded sampling gets its own generator.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                # Gather the grid metadata from every multimodal input of the
                # request; entries that are absent (None) are skipped.
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in self.requests[req_id].mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])

                hf_config = self.model_config.hf_config

                self.requests[req_id].mrope_positions, \
                    self.requests[req_id].mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        self.requests[req_id].prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                    )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        # Sorting in descending order makes `pop()` return the smallest index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()
463

464
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
               Optional[SpecDecodeMetadata]]:
        """Build the per-step model inputs from the persistent batch.

        Fills the runner's reusable CPU buffers (`input_ids_cpu`,
        `positions_np`, `slot_mapping_np`, `query_start_loc_np`,
        `seq_lens_np`) for the tokens scheduled in this step and copies the
        input ids and positions into the persistent GPU buffers.

        Args:
            scheduler_output: The scheduler's decision for this step. Must
                schedule at least one token, and the persistent batch must
                contain at least one request (both are asserted).

        Returns:
            A tuple `(attn_metadata, logits_indices, spec_decode_metadata)`.
            `spec_decode_metadata` is None when no request in this step has
            scheduled speculative tokens; in that case `logits_indices` points
            at the last scheduled token of every request.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        modified_batch = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)
        if modified_batch:
            self.input_batch.refresh_sampling_metadata()

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions.
        # `positions_np` is a slice of the persistent buffer, so the result
        # is written in place via `out=`.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        # Prepare for cascade attention if enabled & beneficial.
        common_prefix_len = 0
        if self.cascade_attn_enabled:
            common_prefix_len = self._compute_cascade_attn_prefix_len(
                num_scheduled_tokens,
                scheduler_output.num_common_prefix_blocks,
            )

        attn_metadata = self.attn_metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=common_prefix_len,
        )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if not use_spec_decode:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = attn_metadata.query_start_loc[1:] - 1
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices, spec_decode_metadata
616

617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
677
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
678
679
680
681
682
683
684
685
686
687
688
689
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
690
        use_cascade = self.attn_backend.use_cascade_attention(
691
692
693
694
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
695
            use_alibi=self.use_alibi,
696
            use_sliding_window=self.window_size is not None,
697
698
699
700
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

701
702
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
703
        for index, req_id in enumerate(self.input_batch.req_ids):
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

753
754
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors needed to verify draft tokens.

        Worked example (used in the step comments below):
            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
        produces
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105,
                                       106, 206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]

        Args:
            num_draft_tokens: Number of draft tokens per request.
            cu_num_scheduled_tokens: Cumulative number of scheduled tokens
                per request.

        Returns:
            A ``SpecDecodeMetadata`` whose index tensors and draft token ids
            have been moved to ``self.device``.
        """

        def to_device(array: np.ndarray) -> torch.Tensor:
            # TODO: Optimize the CPU -> GPU copy.
            return torch.from_numpy(array).to(self.device, non_blocking=True)

        # ---- Logits indices: every position that needs a logit. ----
        # Each request samples its draft positions plus one bonus position.
        # [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
        # Step 1. [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        total_sampled = cu_num_sampled_tokens[-1]
        # Exclusive prefix sum of the sampled counts: [0, 4, 5, 8, 9]
        sampled_starts = cu_num_sampled_tokens - num_sampled_tokens
        # Step 2. [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        sampled_offsets = np.repeat(sampled_starts, num_sampled_tokens)
        # Step 3. [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        within_sampled = self.arange_np[:total_sampled] - sampled_offsets
        # Step 4. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_indices_np = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 5. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        logits_indices_np += within_sampled

        # ---- Bonus logits indices: the last sampled slot per request. ----
        bonus_logits_indices_np = cu_num_sampled_tokens - 1

        # ---- Target logits indices: the slots that verify drafts. ----
        # [3, 3, 5, 5, 6]
        cu_num_draft_tokens_np = np.cumsum(num_draft_tokens, dtype=np.int32)
        total_draft = cu_num_draft_tokens_np[-1]
        # [0, 0, 0, 3, 3, 5]
        draft_offsets = np.repeat(cu_num_draft_tokens_np - num_draft_tokens,
                                  num_draft_tokens)
        # [0, 1, 2, 0, 1, 0]
        within_draft = self.arange_np[:total_draft] - draft_offsets
        # [0, 0, 0, 5, 5, 9]
        target_logits_indices_np = np.repeat(sampled_starts, num_draft_tokens)
        # [0, 1, 2, 5, 6, 9]
        target_logits_indices_np += within_draft

        # Ship everything to the device (same order as the fields below).
        cu_num_draft_tokens = to_device(cu_num_draft_tokens_np)
        logits_indices = to_device(logits_indices_np)
        target_logits_indices = to_device(target_logits_indices_np)
        bonus_logits_indices = to_device(bonus_logits_indices_np)

        # Look up the draft token ids: the input token one slot after each
        # target position.
        # draft_token_indices:      [  1,   2,   3, 105, 106, 208]
        sampled_input_ids = self.input_ids[logits_indices]
        draft_token_ids = sampled_input_ids[target_logits_indices + 1]

        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )

828
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
829
830
831
832
833
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
834
835
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
836
837
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
838
839
840
841
842

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

869
870
871
872
873
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

874
875
            for output in curr_group_outputs:
                encoder_outputs.append(output)
876
877

        # Cache the encoder outputs.
878
879
880
881
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
882
883
884
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}

885
886
887
888
889
890
            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
891
892
        self,
        scheduler_output: "SchedulerOutput",
893
    ) -> list[torch.Tensor]:
894
        mm_embeds: list[torch.Tensor] = []
895
        for req_id in self.input_batch.req_ids:
896
897
898
899
900
901
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
902
903
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
925
926
927
928
929
930
931
932
933
934

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds
935

936
937
938
    def get_model(self) -> nn.Module:
        """Return the model object stored on this runner (``self.model``)."""
        return self.model

939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output token bitmask to ``logits`` in place.

        Does nothing when the scheduler sent no bitmask. Otherwise the
        bitmask rows are rearranged, if needed, to follow this runner's
        batch order before being applied to the rows of the structured
        output requests.

        Args:
            scheduler_output: Provides ``grammar_bitmask`` and
                ``structured_output_request_ids`` (request id -> bitmask
                row).
            logits: Logits tensor that is modified in place.
        """
        # Serialization of np.ndarray is much more efficient than a tensor,
        # so the bitmask arrives in that format.
        bitmask_np = scheduler_output.grammar_bitmask
        if bitmask_np is None:
            return

        # The scheduler does not know how the GPU runner orders requests in
        # the batch, so a request's bitmask row may differ from its batch
        # row. Record the batch row of every structured output request.
        mask_row_of = scheduler_output.structured_output_request_ids
        batch_row_of: dict[str, int] = {}
        for req_id in self.input_batch.req_ids:
            if mask_row_of.get(req_id) is None:
                # Not a structured output request.
                continue
            batch_row_of[req_id] = self.input_batch.req_id_to_index[req_id]

        rows_differ = any(batch_row != mask_row_of[req_id]
                          for req_id, batch_row in batch_row_of.items())
        if rows_differ:
            # Move every bitmask row to the batch row of its request.
            reordered = np.zeros_like(bitmask_np)
            for req_id, batch_row in batch_row_of.items():
                reordered[batch_row] = bitmask_np[mask_row_of[req_id]]
            bitmask_np = reordered

        bitmask = torch.from_numpy(bitmask_np)

        # TODO: compatibility with spec decode
        xgr.apply_token_bitmask_inplace(
            logits,
            bitmask.to(self.device, non_blocking=True),
            indices=list(batch_row_of.values()),
        )

986
987
988
989
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one scheduler step: forward pass, sampling and drafting.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Tensors received from the previous
                pipeline stage. Must be provided on every rank except the
                first one, where it is ignored.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` when no tokens are scheduled, the
            raw hidden states when this is not the last pipeline rank, and
            otherwise a ``ModelRunnerOutput`` with the sampled tokens,
            draft tokens (if speculative decoding is enabled) and logprobs.
        """
        self._update_states(scheduler_output)
        if not scheduler_output.total_num_scheduled_tokens:
            # Nothing to run in this step.
            return EMPTY_MODEL_RUNNER_OUTPUT

        # ------------------------------------------------------------------
        # Prepare the decoder inputs.
        # ------------------------------------------------------------------
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

        fits_cudagraph = (
            self.use_cuda_graph
            and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1])
        if fits_cudagraph:
            # Piecewise CUDA graphs: pad the batch to a captured size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode: no padding.
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        if self.is_multimodal_model:
            # _prepare_inputs may reorder the batch, so the multimodal
            # outputs must be gathered after it to get the correct order.
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)

            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            scheduled_ids = self.input_ids[:num_scheduled_tokens]
            if mm_embeds:
                fresh_embeds = self.model.get_input_embeddings(
                    scheduled_ids, mm_embeds)
            else:
                fresh_embeds = self.model.get_input_embeddings(scheduled_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(fresh_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            mm_embeds = []
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None

        positions = (self.mrope_positions[:, :num_input_tokens]
                     if self.uses_mrope else
                     self.positions[:num_input_tokens])

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            # Copy the received tensors into the runner's own buffers and
            # feed the model slices of those buffers.
            for name, received in intermediate_tensors.items():
                self.intermediate_tensors[name][:num_input_tokens].copy_(
                    received[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                name: buffer[:num_input_tokens]
                for name, buffer in self.intermediate_tensors.items()
            })

        # ------------------------------------------------------------------
        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        # ------------------------------------------------------------------
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            return hidden_states

        # Drop the padding rows before computing logits.
        hidden_states = hidden_states[:num_scheduled_tokens]
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Apply structured output bitmasks if present.
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # ------------------------------------------------------------------
        # Sample the next token and get logprobs if needed.
        # ------------------------------------------------------------------
        sampling_metadata = self.input_batch.sampling_metadata
        if spec_decode_metadata is not None:
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on
            # bonus_logits won't affect the original logits tensor.
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.model.sample(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids
        else:
            sampler_output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        partial_prefill_rows = []
        for row, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                continue
            # Partial prefill: the sampled token must be ignored.
            # Rewind the generator state as if the token was not sampled.
            # This relies on cuda-specific torch-internal impl details.
            generator = self.input_batch.generators.get(row)
            if generator is not None:
                generator.set_offset(generator.get_offset() - 4)
            # Remember the row so its sampled tokens can be cleared before
            # returning.
            partial_prefill_rows.append(row)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        if logprobs_tensors is not None:
            logprobs_lists = logprobs_tensors.tolists()
        else:
            logprobs_lists = None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
        # Mask out the sampled tokens that should not be sampled.
        for row in partial_prefill_rows:
            valid_sampled_token_ids[row].clear()

        # ------------------------------------------------------------------
        # Propose draft tokens for the next step.
        # ------------------------------------------------------------------
        if not self.use_spec_decode:
            # Speculative decoding is not enabled.
            spec_token_ids = None
        elif self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids, sampling_metadata)
        elif self.speculative_config.method == "eagle":
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            next_ids: list[int] = []
            for row, token_ids in enumerate(valid_sampled_token_ids):
                if token_ids:
                    # Common case: continue from the last sampled token.
                    next_ids.append(token_ids[-1])
                    continue
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = self.input_batch.req_ids[row]
                req_state = self.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_ids.append(req_state.get_token_id(seq_len))
            next_token_ids = torch.tensor(next_ids,
                                          dtype=torch.int32,
                                          device=self.device)

            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                # We need to slice token_ids, positions, and hidden_states
                # because the eagle head does not use cuda graph and should
                # not include padding.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
                target_positions = positions[:num_scheduled_tokens]
                target_hidden_states = hidden_states[:num_scheduled_tokens]
                target_slot_mapping = attn_metadata.slot_mapping
                cu_num_tokens = attn_metadata.query_start_loc
            else:
                # TODO(woosuk): Refactor this.
                # Per request: drafts plus the bonus slot, minus the tokens
                # that were actually kept; 0 for requests without drafts.
                rejected_counts = [
                    n + 1 - len(valid_sampled_token_ids[row]) if n > 0 else 0
                    for row, n in enumerate(
                        spec_decode_metadata.num_draft_tokens)
                ]
                num_rejected_tokens = torch.tensor(
                    rejected_counts,
                    dtype=torch.int32,
                    device=self.device,
                )
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
                    attn_metadata.query_start_loc,
                    num_rejected_tokens,
                )
                target_token_ids = self.input_ids[token_indices]
                target_positions = positions[token_indices]
                target_hidden_states = hidden_states[token_indices]
                target_slot_mapping = attn_metadata.slot_mapping[token_indices]

            draft_token_ids, draft_probs = self.drafter.propose(
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
                block_table=attn_metadata.block_table,
                sampling_metadata=sampling_metadata,
            )
            spec_token_ids = draft_token_ids.tolist()
            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
            # in the next step.
            del draft_probs

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

1242
1243
    def generate_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
        sampling_metadata: SamplingMetadata,
    ) -> list[list[int]]:
        """Propose draft tokens for every request using ``self.drafter``.

        Args:
            sampled_token_ids: Tokens sampled in this step, one list per
                request in batch order. An empty list means nothing was
                sampled for that request.
            sampling_metadata: Not used by this method.

        Returns:
            One list of draft token ids per request. The list is empty when
            nothing was sampled, when ``is_spec_decode_supported`` rejects
            the request, or when the drafter proposes nothing.
        """
        # TODO(woosuk): Optimize.
        all_drafts: list[list[int]] = []
        for row, new_ids in enumerate(sampled_token_ids):
            drafts: list[int] = []
            num_new = len(new_ids)
            # Draft only if something was sampled and the request does not
            # require top-p, top-k, etc.
            if num_new > 0 and is_spec_decode_supported(
                    self.input_batch.req_ids[row], self.input_batch):
                # Add the sampled tokens to token_ids_cpu so the drafter
                # sees them.
                begin = self.input_batch.num_tokens_no_spec[row]
                end = begin + num_new
                self.input_batch.token_ids_cpu[row, begin:end] = new_ids
                proposal = self.drafter.propose(
                    self.input_batch.token_ids_cpu[row, :end])
                if proposal is not None and len(proposal) > 0:
                    drafts = proposal.tolist()
            all_drafts.append(drafts)
        return all_drafts

1274
1275
1276
    def load_model(self) -> None:
        """Load the model (plus LoRA wrapper and drafter, when configured).

        Stores the loaded model in ``self.model`` and the memory consumed
        during loading, as reported by ``DeviceMemoryProfiler``, in
        ``self.model_memory_usage``. The load time and memory are logged.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as mem_profiler:  # noqa: SIM117
            load_started = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model,
                    self.model_config,
                    self.scheduler_config,
                    self.lora_config,
                    self.device,
                )
            # The drafter attribute is only present in some configurations.
            if hasattr(self, "drafter"):
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            load_seconds = time.perf_counter() - load_started

        # Read the profiler result only after the context has exited.
        self.model_memory_usage = mem_profiler.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes, load_seconds)
1293

1294
1295
1296
1297
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for the requests that asked for them.

        Prompt logprobs are accumulated chunk by chunk into per-request CPU
        tensors kept in ``self.input_batch.in_progress_prompt_logprobs_cpu``.
        A request's tensors are returned only in the step where its last
        prompt chunk is scheduled; at that point the request is also removed
        from the batch's prompt-logprobs bookkeeping.

        Args:
            hidden_states: Hidden states of the current step, indexed by
                token position in the batch.
            scheduler_output: Provides ``num_scheduled_tokens`` per request.

        Returns:
            Mapping from request id to its completed ``LogprobsTensors``.
            Empty when no request in the batch wants prompt logprobs.
        """
        wanted = self.input_batch.num_prompt_logprobs
        if not wanted:
            return {}

        partial_results = self.input_batch.in_progress_prompt_logprobs_cpu
        finished: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        done_req_ids = []
        for req_id, num_prompt_logprobs in wanted.items():
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Per-request metadata.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)

            # Fetch (or lazily create) the CPU tensors that collect this
            # request's prompt logprobs across chunks.
            logprobs_tensors = partial_results.get(req_id)
            if not logprobs_tensors:
                # Sized for the whole prompt; with chunked prefill each
                # step fills in its own slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                partial_results[req_id] = logprobs_tensors

            # Work out how many logits this step contributes.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            # When num_tokens == num_remaining_tokens there is nothing left
            # to produce after this step, but returning is deferred to the
            # next step, where newly generated tokens are available too.
            is_last_chunk = num_tokens > num_remaining_tokens
            if is_last_chunk:
                num_logits = num_remaining_tokens
                done_req_ids.append(req_id)
                finished[req_id] = logprobs_tensors
            else:
                num_logits = num_tokens

            if num_logits <= 0:
                # Possible on the final chunk when the previous step
                # prefilled exactly (num_prompt_tokens - 1) tokens: no
                # prompt logprobs remain to be produced.
                continue

            # Slice out the hidden states of this request's prompt tokens.
            # For a partial request (chunked prefill) every index yields a
            # prompt logprob.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            logits = self.model.compute_logits(
                hidden_states[offset:offset + num_logits], None)

            # The "target" for prompt index i is the token at index i+1;
            # its logprob is what gets gathered.
            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

            # Compute prompt logprobs.
            sampler = self.model.sampler
            all_logprobs = sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = sampler.gather_logprobs(
                all_logprobs, num_prompt_logprobs, tgt_token_ids)

            # Asynchronous GPU->CPU copies into this chunk's slice.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(
                logprobs, non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Requests that completed prefill no longer need tracking.
        for req_id in done_req_ids:
            del wanted[req_id]
            del partial_results[req_id]

        # The non-blocking GPU->CPU transfers must have landed before the
        # tensors are handed back.
        if finished:
            torch.cuda.synchronize()

        return finished

1389
1390
1391
1392
1393
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
    ) -> torch.Tensor:
        """Run one forward pass over ``num_tokens`` placeholder tokens.

        The tokens are taken from the runner's preallocated input buffers
        and split across dummy requests so that the request count does not
        exceed ``max_num_seqs``.

        Args:
            num_tokens: Total number of tokens to feed; must not exceed
                ``max_num_batched_tokens``.

        Returns:
            The hidden states at the last token of every dummy request.
        """
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        num_reqs = min(num_tokens, self.scheduler_config.max_num_seqs)
        base_tokens, leftover = divmod(num_tokens, num_reqs)
        tokens_per_req = [base_tokens] * num_reqs
        # The last request absorbs whatever does not divide evenly.
        tokens_per_req[-1] += leftover
        assert sum(tokens_per_req) == num_tokens
        assert len(tokens_per_req) == num_reqs
        num_scheduled_tokens = np.array(tokens_per_req, dtype=np.int32)

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
            # Multimodal models are driven through embeddings, text-only
            # models through token ids.
            if self.is_multimodal_model:
                input_ids, inputs_embeds = None, self.inputs_embeds[:num_tokens]
            else:
                input_ids, inputs_embeds = self.input_ids[:num_tokens], None

            positions = (self.mrope_positions[:, :num_tokens]
                         if self.uses_mrope else self.positions[:num_tokens])

            # Only ranks after the first pipeline stage take intermediate
            # tensors as input; the buffer is created on first use.
            intermediate_tensors = None
            if not get_pp_group().is_first_rank:
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = IntermediateTensors({
                    name: buf[:num_tokens]
                    for name, buf in self.intermediate_tensors.items()
                })

            with set_forward_context(None,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                hidden_states = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )

        # Index of the final token of each dummy request.
        last_token_indices = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states[last_token_indices]

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Warm up the sampler (and rejection sampler) with dummy inputs.

        Args:
            hidden_states: One hidden-state row per dummy request.

        Returns:
            The output of ``self.model.sample`` on the dummy batch.

        Raises:
            RuntimeError: If sampling fails. An out-of-memory failure is
                re-raised with a hint about which settings to lower.
        """
        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def filled(value):
            # One entry per dummy request, all set to `value`.
            return torch.full((num_reqs, ), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=filled(0.5),
            all_greedy=False,
            all_random=False,
            top_p=filled(0.9),
            top_k=filled(logits.size(1) - 1),
            min_p=None,
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=filled(0.1),
            presence_penalties=filled(0.1),
            repetition_penalties=filled(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
        )

        try:
            sampler_output = self.model.sample(
                logits=logits, sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise e
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine.") from e

        if self.use_spec_decode:
            # One dummy draft token per request.
            draft_token_ids = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                draft_token_ids, self.device)
            num_tokens = sum(len(ids) for ids in draft_token_ids)

            # Draft probabilities are not supplied in this warm-up.
            draft_probs = None
            target_logits = torch.randn(num_tokens,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                draft_probs,
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output
1522
1523

    def profile_run(self) -> None:
        """Run a worst-case dummy step so peak memory can be observed.

        For multimodal models with a non-zero encoder budget, the multimodal
        encoder is run first on a dummy batch and its outputs are placed in
        the encoder cache. A dummy forward pass over ``max_num_tokens``
        tokens follows, plus a dummy sampler run on the last pipeline rank.
        All temporaries, including the encoder cache entries, are released
        before returning.
        """
        # Profile with multimodal encoder & encoder cache.
        # TODO: handle encoder-decoder models once we support them.
        profile_encoder = (self.is_multimodal_model
                           and self.max_num_encoder_input_tokens > 0
                           and self.encoder_cache_size > 0)
        if profile_encoder:
            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            tokens_per_item_by_modality = (
                self.mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                    self.model_config))
            dummy_data_modality, max_tokens_per_mm_item = max(
                tokens_per_item_by_modality.items(),
                key=lambda entry: entry[1])

            # How many items of this modality the encoder budget allows.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)
            items_by_encoder_budget = cdiv(encoder_budget,
                                           max_tokens_per_mm_item)

            # How many items of this modality the decoder budget allows.
            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            per_req_limit = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]
            items_by_decoder_budget = self.max_num_reqs * per_req_limit

            max_num_mm_items = min(items_by_encoder_budget,
                                   items_by_decoder_budget)

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Build a dummy batch by repeating one dummy multimodal item.
            dummy_request_data = self.mm_registry.get_decoder_dummy_data(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_counts={dummy_data_modality: 1},
            )
            dummy_mm_kwargs = dummy_request_data.multi_modal_data
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch([dummy_mm_kwargs] * max_num_mm_items),
                device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)

            sanity_check_mm_encoder_outputs(
                dummy_encoder_outputs,
                expected_num_items=max_num_mm_items,
            )

            # Keep the dummy encoder outputs cached while the dummy forward
            # pass below runs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        hidden_states = self._dummy_run(self.max_num_tokens)
        sampler_output = (self._dummy_sampler_run(hidden_states)
                          if get_pp_group().is_last_rank else None)
        torch.cuda.synchronize()

        # Drop everything created for profiling.
        del hidden_states, sampler_output
        self.encoder_cache.clear()
        gc.collect()
1599
1600

    def capture_model(self) -> None:
        """Capture CUDA graphs for every configured batch size.

        Does nothing except log a warning when CUDA graphs are disabled.
        Otherwise runs the warm-up passes and one further dummy run per
        size inside ``graph_capture``, then logs the elapsed time and the
        drop in free GPU memory.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        capture_started = time.perf_counter()
        free_before, _ = torch.cuda.mem_get_info()

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                compilation_config = self.vllm_config.compilation_config
                num_warmups = compilation_config.cudagraph_num_of_warmups
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens)
                # The run that follows the warm-ups.
                self._dummy_run(num_tokens)

        capture_finished = time.perf_counter()
        free_after, _ = torch.cuda.mem_get_info()
        elapsed_time = capture_finished - capture_started
        cuda_graph_size = free_before - free_after
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
1627

1628
1629
1630
1631
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Allocate the per-layer KV cache tensors and bind them to the model.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
            cache size of each layer

        Raises:
            NotImplementedError: If more than one KV cache group is given.
            ValueError: If a group's spec is not an ``AttentionSpec``.
        """
        groups = kv_cache_config.kv_cache_groups
        if len(groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for group in groups:
            spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                layer_bytes = kv_cache_config.tensors[layer_name].size
                page_bytes = spec.page_size_bytes
                assert layer_bytes % page_bytes == 0
                num_blocks = layer_bytes // page_bytes
                # `num_blocks` is the number of blocks the model runner can use.
                # `kv_cache_config.num_blocks` is the number of blocks that
                # KVCacheManager may allocate.
                # Since different GPUs may have different number of layers and
                # different memory capacities, `num_blocks` can be different on
                # different GPUs, and `kv_cache_config.num_blocks` is set to
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks

                if not isinstance(spec, AttentionSpec):
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")

                cache_shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)
                kv_caches[layer_name] = torch.zeros(cache_shape,
                                                    dtype=spec.dtype,
                                                    device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

1674
    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Build the KV cache spec of every layer that needs a KV cache, by
        inspecting the Attention modules in the static forward context.

        Returns:
            A dictionary mapping layer names to their KV cache format.
            Layers that do not need KV cache are not included.

        Raises:
            NotImplementedError: For encoder-decoder attention layers.
            ValueError: For an attention type that is not recognised.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla

        specs: dict[str, KVCacheSpec] = {}
        for layer_name, attn_module in forward_ctx.items():
            # FusedMoE entries in the forward context are skipped.
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., cross-attention
            assert isinstance(attn_module, Attention)
            attn_type = attn_module.attn_type

            if attn_type in (AttentionType.ENCODER,
                             AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            if attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            if attn_type != AttentionType.DECODER:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

            # Decoder layer: the spec class depends on whether the layer
            # uses a sliding window.
            shared_fields = dict(
                block_size=block_size,
                num_kv_heads=attn_module.num_kv_heads,
                head_size=attn_module.head_size,
                dtype=self.kv_cache_dtype,
                use_mla=use_mla,
            )
            window = attn_module.sliding_window
            if window is None:
                specs[layer_name] = FullAttentionSpec(**shared_fields)
            else:
                specs[layer_name] = SlidingWindowSpec(sliding_window=window,
                                                      **shared_fields)

        return specs