# gpu_model_runner.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
import weakref
6
from typing import TYPE_CHECKING, Optional, Union
7
8
9
10
11
12

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

13
from vllm.attention import AttentionType, get_attn_backend
14
from vllm.attention.layer import Attention
15
from vllm.attention.utils.fa_utils import get_flash_attn_version
16
17
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
18
19
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
20
from vllm.distributed.parallel_state import get_pp_group, graph_capture
21
22
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
23
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
24
from vllm.model_executor.model_loader import get_model
25
26
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
27
from vllm.multimodal.utils import group_mm_inputs_by_modality
28
from vllm.sampling_params import SamplingType
29
from vllm.sequence import IntermediateTensors
30
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
33
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
34
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
35
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
36
37
38
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec,
                                        SlidingWindowSpec)
39
40
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
41
from vllm.v1.sample.metadata import SamplingMetadata
42
from vllm.v1.sample.rejection_sampler import RejectionSampler
43
from vllm.v1.sample.sampler import Sampler
44
from vllm.v1.spec_decode.eagle import EagleProposer
45
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
46
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
47
from vllm.v1.spec_decode.utils import is_spec_decode_supported
48
from vllm.v1.utils import bind_kv_cache
49
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
50
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
51

52
53
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
54

55
if TYPE_CHECKING:
    # Imports needed only by static type checkers (annotations such as
    # "SchedulerOutput" are written as strings in this module).
    import xgrammar as xgr

    from vllm.v1.core.sched.output import SchedulerOutput
else:
    # At runtime, xgrammar is wrapped in a LazyLoader; presumably this defers
    # the real import until the module is first used — confirm against
    # vllm.utils.LazyLoader.
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the runner from the engine configuration.

        Caches the individual sub-configs, selects the attention backend,
        sets up speculative decoding (if configured), creates the persistent
        ``InputBatch`` and pre-allocates the CPU/GPU buffers that are reused
        on every step instead of being re-created.

        Args:
            vllm_config: Engine-wide configuration.
            device: Device on which the persistent GPU buffers are allocated.

        Raises:
            NotImplementedError: If ``get_attn_backend`` returns ``None``.
            ValueError: If ``full_cuda_graph`` is enabled with a backend other
                than FlashAttention version 3, or if the speculative decoding
                method is not recognized.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
        # cpu_offload_gb is expressed in GiB; the setter takes bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Short local aliases used throughout the constructor.
        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        if cache_config.cache_dtype == "auto":
            # "auto" means the KV cache uses the model's own dtype.
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            # Log the full selection context before failing so the
            # unsupported combination can be diagnosed.
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        if self.vllm_config.compilation_config.full_cuda_graph:
            # Full CUDA graph capture is only allowed with FlashAttention 3.
            attn_backend_name = self.attn_backend.__name__
            flash_attn_version = get_flash_attn_version()
            if (attn_backend_name != "FlashAttentionBackend"
                    or flash_attn_version != 3):
                raise ValueError(
                    f"full_cuda_graph is only supported with "
                    f"FA3. Current attention backend is {attn_backend_name}, "
                    f"FlashAttention version is {flash_attn_version}.")

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
            mm_registry=self.mm_registry,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Sampler
        self.sampler = Sampler()

        # Lazy initializations
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        # self.kv_cache_config: KVCacheConfig
        # self.attn_metadata_builder: type[AttentionMetadataBuilder]

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        self.use_aux_hidden_state_outputs = False
        if self.speculative_config:
            self.use_spec_decode = True
            # The drafter and rejection sampler are only created on the last
            # pipeline-parallel rank.
            if get_pp_group().is_last_rank:
                if self.speculative_config.method == "ngram":
                    self.drafter = NgramProposer(self.vllm_config)
                elif self.speculative_config.use_eagle():
                    self.drafter = EagleProposer(self.vllm_config,
                                                 self.device)  # type: ignore
                    if self.speculative_config.method == "eagle3":
                        self.use_aux_hidden_state_outputs = True
                else:
                    raise ValueError("Unknown speculative decoding method: "
                                     f"{self.speculative_config.method}")
                self.rejection_sampler = RejectionSampler()

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
                                           dtype=torch.int32,
                                           device=self.device)
        self.seq_lens = torch.zeros(self.max_num_reqs,
                                    dtype=torch.int32,
                                    device=self.device)
        self.slot_mapping = torch.zeros(self.max_num_tokens,
                                        dtype=torch.int64,
                                        device=self.device)

        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        # Only relevant for models using ALiBi (e.g, MPT)
        self.use_alibi = check_use_alibi(model_config)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Keep in int64 to avoid overflow with long context
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int64)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each *_np attribute is a NumPy view sharing memory with the
        # corresponding CPU tensor (created via Tensor.numpy()).
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The SamplingMetadata is updated and copied to the GPU if there is a
        new/resumed/paused/finished request in the batch.

        Args:
            scheduler_output: Output of the scheduler for the current step.
        """
        # Drop finished requests from the cached states and from the
        # persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: list[int] = []
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    # Drop the per-request dict once it becomes empty.
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: list[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                # Seeded sampling gets its own generator on the runner device.
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                self._init_request_mrope_positions(req_state)

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])

            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            self.input_batch.block_table.append_row(req_data.new_block_ids,
                                                    req_index)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = bool(removed_req_indices or req_ids_to_add)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first: the list is sorted in
        # descending order so that pop() yields the smallest index.
        removed_req_indices.sort(reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        batch_reordered = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)

        if batch_changed or batch_reordered:
            self.input_batch.refresh_sampling_metadata()

    def _init_request_mrope_positions(
            self, req_state: CachedRequestState) -> None:
        """Compute M-RoPE positions for a newly added request.

        Collects the image/video grid, per-grid timing and audio metadata from
        the request's multimodal inputs and stores the positions and position
        delta returned by `MRotaryEmbedding.get_input_positions_tensor` on
        ``req_state``.

        Args:
            req_state: Cached state of the request; updated in place.
        """
        image_grid_thw = []
        video_grid_thw = []
        second_per_grid_ts = []
        audio_feature_lengths = []
        use_audio_in_video = False
        for mm_input in req_state.mm_inputs:
            if mm_input.get("image_grid_thw") is not None:
                image_grid_thw.extend(mm_input["image_grid_thw"].tolist())
            if mm_input.get("video_grid_thw") is not None:
                video_grid_thw.extend(mm_input["video_grid_thw"].tolist())
            if mm_input.get("second_per_grid_ts") is not None:
                second_per_grid_ts.extend(mm_input["second_per_grid_ts"])
            if mm_input.get("audio_feature_lengths") is not None:
                audio_feature_lengths.extend(
                    mm_input["audio_feature_lengths"])
            if mm_input.get("use_audio_in_video") is True:
                use_audio_in_video = True

        hf_config = self.model_config.hf_config

        req_state.mrope_positions, req_state.mrope_position_delta = (
            MRotaryEmbedding.get_input_positions_tensor(
                req_state.prompt_token_ids,
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            ))

    def _prepare_inputs(
509
510
        self,
        scheduler_output: "SchedulerOutput",
511
    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
512
               Optional[SpecDecodeMetadata]]:
513
514
515
516
517
518
519
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
520
        self.input_batch.block_table.commit(num_reqs)
521
522

        # Get the number of scheduled tokens for each request.
523
524
525
526
        req_ids = self.input_batch.req_ids
        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)
527
528
529

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
530
531
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
532
533
534

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
535
536
537
538
539
540
541
542
543
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
544
545

        # Get positions.
546
        positions_np = self.positions_np[:total_num_scheduled_tokens]
547
548
549
550
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

551
552
        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
553
        if self.uses_mrope:
554
555
            self._calc_mrope_positions(scheduler_output)

556
557
558
559
        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
560
561
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
562

563
564
565
566
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
567
                           0,
568
569
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
570
571

        # Calculate the slot mapping.
572
573
574
575
576
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
577
578
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
579
580
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
581
582
583
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
584
585
               out=self.input_batch.block_table.
               slot_mapping_np[:total_num_scheduled_tokens])
586
587

        # Prepare the attention metadata.
588
        self.query_start_loc_np[0] = 0
589
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
590

591
592
593
        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
594
595
596
597

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
598
        if self.uses_mrope:
599
600
601
602
603
604
605
606
607
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)
608

609
610
611
612
613
614
615
616
617
618
619
620
        self.query_start_loc[:num_reqs + 1].copy_(
            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                       non_blocking=True)

        # Fill unused with -1. Needed for reshape_and_cache
        self.seq_lens[num_reqs:].fill_(0)
        self.query_start_loc[num_reqs + 1:].fill_(-1)

        query_start_loc = self.query_start_loc[:num_reqs + 1]
        seq_lens = self.seq_lens[:num_reqs]

621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
        common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=query_start_loc, seq_lens=seq_lens)

        attn_metadata: dict[str, FlashAttentionMetadata] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        # NOTE(Chen): there is exactly one KV cache group that contains all
        # attetnion layers in the model for now, so the current logic for
        # getting attn_metadata is not related to kv_cache_group information.
        # Will extend this part to support multiple KV cache groups later.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

            # Prepare for cascade attention if enabled & beneficial.
            common_prefix_len = 0
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.num_common_prefix_blocks,
                )
641

642
643
644
645
646
647
648
649
            attn_metadata_i = self.attn_metadata_builder.build(
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)
            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i
650

651
652
        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
653
        if not use_spec_decode:
654
655
656
657
658
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
659
            logits_indices = query_start_loc[1:] - 1
660
661
662
663
664
665
666
667
668
669
670
671
672
673
            spec_decode_metadata = None
        else:
            # Get the number of draft tokens for each request.
            # Iterate over the dictionary rather than all requests since not all
            # requests have draft tokens.
            num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
            for req_id, draft_token_ids in (
                    scheduler_output.scheduled_spec_decode_tokens.items()):
                req_idx = self.input_batch.req_id_to_index[req_id]
                num_draft_tokens[req_idx] = len(draft_token_ids)

            spec_decode_metadata = self._calc_spec_decode_metadata(
                num_draft_tokens, cu_num_tokens)
            logits_indices = spec_decode_metadata.logits_indices
674

675
676
677
678
        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

679
        return attn_metadata, logits_indices, spec_decode_metadata
680

681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
741
        # Request 3's num_computed_tokens: 3 (i.e., [A, B, C])
742
743
744
745
746
747
748
749
750
751
752
753
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
754
        use_cascade = self.attn_metadata_builder.use_cascade_attention(
755
756
757
758
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
759
            use_alibi=self.use_alibi,
760
            use_sliding_window=self.window_size is not None,
761
762
763
764
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

765
766
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
767
        for index, req_id in enumerate(self.input_batch.req_ids):
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

817
818
    def _calc_spec_decode_metadata(
        self,
        num_draft_tokens: np.ndarray,
        cu_num_scheduled_tokens: np.ndarray,
    ) -> SpecDecodeMetadata:
        """Build the index tensors needed to verify draft tokens.

        A request with ``k`` draft tokens needs ``k + 1`` logits rows, taken
        from its last ``k + 1`` scheduled positions. Within that segment,
        the first ``k`` rows are the "target" rows used to verify the draft
        tokens. The last row samples the bonus token. ``logits_indices``
        index the hidden states. The target/bonus indices index into the
        rows selected by ``logits_indices``.

        Worked example:
            cu_num_scheduled_tokens:  [  4, 104, 107, 207, 209]
            num_draft_tokens:         [  3,   0,   2,   0,   1]
            ->
            cu_num_draft_tokens:      [  3,   3,   5,   5,   6]
            logits_indices:           [  0,   1,   2,   3, 103, 104, 105, 106,
                                       206, 207, 208]
            target_logits_indices:    [  0,   1,   2,   5,   6,   9]
            bonus_logits_indices:     [  3,   4,   7,   8,  10]

        Args:
            num_draft_tokens: Number of draft tokens per request.
            cu_num_scheduled_tokens: Cumulative number of scheduled tokens
                per request.

        Returns:
            SpecDecodeMetadata whose index tensors have been copied to
            ``self.device``.
        """
        arange_np = self.arange_np

        def _segment_offsets(counts: np.ndarray,
                             cu_counts: np.ndarray) -> np.ndarray:
            # Given segment lengths `counts` and their cumsum `cu_counts`,
            # return each element's offset inside its own segment,
            # e.g. counts [3, 0, 2] -> [0, 1, 2, 0, 1].
            segment_starts = np.repeat(cu_counts - counts, counts)
            return arange_np[:cu_counts[-1]] - segment_starts

        # e.g. [4, 1, 3, 1, 2]
        num_sampled_tokens = num_draft_tokens + 1
        # e.g. [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        # e.g. [3, 3, 5, 5, 6]
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)

        # Hidden-state rows: the last (k + 1) scheduled positions of each
        # request. Segment starts e.g. [0, 103, 104, 206, 207], repeated and
        # offset -> [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208].
        # In-place `+=` keeps the dtype of the repeated array.
        logits_indices = np.repeat(
            cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens)
        logits_indices += _segment_offsets(num_sampled_tokens,
                                           cu_num_sampled_tokens)

        # Last row of each request's segment, e.g. [3, 4, 7, 8, 10].
        bonus_logits_indices = cu_num_sampled_tokens - 1

        # First k rows of each request's segment.
        # Segment starts e.g. [0, 0, 0, 5, 5, 9] -> [0, 1, 2, 5, 6, 9].
        target_logits_indices = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens)
        target_logits_indices += _segment_offsets(num_draft_tokens,
                                                  cu_num_draft_tokens)

        # TODO: Optimize the CPU -> GPU copy.
        device = self.device
        cu_num_draft_tokens_dev = torch.from_numpy(cu_num_draft_tokens).to(
            device, non_blocking=True)
        logits_indices_dev = torch.from_numpy(logits_indices).to(
            device, non_blocking=True)
        target_logits_indices_dev = torch.from_numpy(
            target_logits_indices).to(device, non_blocking=True)
        bonus_logits_indices_dev = torch.from_numpy(bonus_logits_indices).to(
            device, non_blocking=True)

        # Each draft token is the input token one row after its target row,
        # e.g. input positions [1, 2, 3, 105, 106, 208].
        selected_input_ids = self.input_ids[logits_indices_dev]
        draft_token_ids = selected_input_ids[target_logits_indices_dev + 1]

        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens_dev,
            target_logits_indices=target_logits_indices_dev,
            bonus_logits_indices=bonus_logits_indices_dev,
            logits_indices=logits_indices_dev,
        )

892
    def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
893
894
895
896
897
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
898
899
        mm_inputs = list[MultiModalKwargs]()
        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
900
901
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
902
903
904
905
906

            for mm_input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[mm_input_id])
                req_ids_pos.append(
                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

933
934
935
936
937
            sanity_check_mm_encoder_outputs(
                curr_group_outputs,
                expected_num_items=len(grouped_mm_inputs),
            )

938
939
            for output in curr_group_outputs:
                encoder_outputs.append(output)
940
941

        # Cache the encoder outputs.
942
943
944
945
        for (req_id, input_id, pos_info), output in zip(
                req_ids_pos,
                encoder_outputs,
        ):
946
947
948
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}

949
950
951
952
953
954
            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
                output,
                is_embed=pos_info.is_embed,
            )

    def _gather_mm_embeddings(
955
956
        self,
        scheduler_output: "SchedulerOutput",
957
    ) -> list[torch.Tensor]:
958
        mm_embeds: list[torch.Tensor] = []
959
        for req_id in self.input_batch.req_ids:
960
961
962
963
964
965
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
966
967
                start_pos = pos_info.offset
                num_encoder_tokens = pos_info.length
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
989
990
991
992
993
994
995
996
997
998

                if (is_embed := pos_info.is_embed) is not None:
                    is_embed = is_embed[start_idx:end_idx]

                mm_embeds_item = gather_mm_placeholders(
                    encoder_output[start_idx:end_idx],
                    is_embed=is_embed,
                )
                mm_embeds.append(mm_embeds_item)
        return mm_embeds
999

1000
1001
1002
    def get_model(self) -> nn.Module:
        """Return the model module held by this runner (``self.model``)."""
        model = self.model
        return model

1003
1004
1005
1006
1007
1008
1009
1010
1011
    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        """Apply the structured-output grammar bitmask to ``logits`` in place.

        The scheduler sends a bitmask compacted to structured-output requests
        only. Its row order follows
        ``scheduler_output.structured_output_request_ids`` and can differ
        from this runner's batch order. Each structured-output request owns
        ``1 + num_spec_tokens`` consecutive rows, both in the bitmask and in
        ``logits``. This method scatters the compacted rows into a bitmask
        aligned with ``logits``. It then masks only the affected rows via
        ``xgr.apply_token_bitmask_inplace(..., indices=...)``.

        Args:
            scheduler_output: Scheduler output for this step. Nothing is
                done if its ``grammar_bitmask`` is None.
            logits: Logits whose rows are laid out in batch order, with one
                extra row per scheduled speculative token. Modified in place.
        """
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        spec_tokens = scheduler_output.scheduled_spec_decode_tokens
        structured_req_ids = scheduler_output.structured_output_request_ids

        # Pass 1: find the first logits row of every structured-output
        # request. Earlier requests in the batch occupy one row each, plus
        # one row per scheduled speculative token, so the batch index is
        # shifted by the running count of speculative tokens.
        struct_out_req_batch_indices: dict[str, int] = {}
        cumulative_offset = 0
        for req_id, batch_index in sorted(
                self.input_batch.req_id_to_index.items(), key=lambda x: x[1]):
            if req_id in structured_req_ids:
                struct_out_req_batch_indices[req_id] = (batch_index +
                                                        cumulative_offset)
            cumulative_offset += len(spec_tokens.get(req_id, []))

        # Pass 2: copy each request's block of bitmask rows to its logits
        # rows (one slice copy per request instead of one copy per row) and
        # record which logits rows must be masked.
        sorted_bitmask = np.zeros_like(grammar_bitmask,
                                       shape=(logits.shape[0],
                                              grammar_bitmask.shape[1]))
        out_indices: list[int] = []
        cumulative_index = 0
        for req_id, _ in sorted(structured_req_ids.items(),
                                key=lambda x: x[1]):
            logit_index = struct_out_req_batch_indices[req_id]
            num_rows = 1 + len(spec_tokens.get(req_id, []))
            sorted_bitmask[logit_index:logit_index + num_rows] = \
                grammar_bitmask[cumulative_index:cumulative_index + num_rows]
            out_indices.extend(range(logit_index, logit_index + num_rows))
            cumulative_index += num_rows

        # Serialization of np.ndarray is much more efficient than a tensor,
        # so the bitmask is received in that format and converted here.
        xgr.apply_token_bitmask_inplace(
            logits,
            torch.from_numpy(sorted_bitmask).to(self.device, non_blocking=True),
            indices=out_indices,
        )

1062
1063
1064
1065
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
1066
        intermediate_tensors: Optional[IntermediateTensors] = None,
1067
    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
1068
1069
1070
1071
1072
        # Update KVConnector with the KVConnector metadata forward().
        if has_kv_transfer_group():
            get_kv_transfer_group().bind_connector_metadata(
                scheduler_output.kv_connector_metadata)

1073
        self._update_states(scheduler_output)
1074
        if not scheduler_output.total_num_scheduled_tokens:
1075
            # Return empty ModelRunnerOutput if there's no work to do.
1076
            return EMPTY_MODEL_RUNNER_OUTPUT
1077
1078

        # Prepare the decoder inputs.
1079
1080
        attn_metadata, logits_indices, spec_decode_metadata = (
            self._prepare_inputs(scheduler_output))
1081
1082
1083
1084
1085
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
1086
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
1087
1088
1089
                num_scheduled_tokens)
        else:
            # Eager mode.
1090
1091
1092
1093
1094
1095
1096
1097
1098
            # Pad tokens to multiple of tensor_parallel_size when
            # enabled collective fusion for SP
            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
            if self.vllm_config.compilation_config.pass_config. \
                enable_sequence_parallelism and tp_size > 1:
                from vllm.utils import round_up
                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
            else:
                num_input_tokens = num_scheduled_tokens
1099

1100
1101
1102
1103
1104
1105
1106
1107
1108
        # _prepare_inputs may reorder the batch, so we must gather multi
        # modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
            mm_embeds = self._gather_mm_embeddings(scheduler_output)
        else:
            mm_embeds = []

1109
1110
1111
1112
1113
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
1114
            if mm_embeds:
1115
                inputs_embeds = self.model.get_input_embeddings(
1116
                    input_ids, mm_embeds)
1117
1118
1119
1120
1121
1122
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
1123
        else:
1124
1125
1126
1127
1128
1129
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
1130
1131
1132
1133
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]
1134

1135
1136
1137
        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
1138
1139
1140
1141
1142
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            for k, v in intermediate_tensors.items():
                self.intermediate_tensors[k][:num_input_tokens].copy_(
                    v[:num_input_tokens], non_blocking=True)
1143
1144
1145
1146
1147
            intermediate_tensors = IntermediateTensors({
                k: v[:num_input_tokens]
                for k, v in self.intermediate_tensors.items()
            })

1148
1149
        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
1150
1151
1152
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
1153
            output = self.model(
1154
                input_ids=input_ids,
1155
                positions=positions,
1156
                intermediate_tensors=intermediate_tensors,
1157
                inputs_embeds=inputs_embeds,
1158
            )
1159
1160
1161
1162
1163
1164

        if self.use_aux_hidden_state_outputs:
            hidden_states, aux_hidden_states = output
        else:
            hidden_states = output

1165
        if not get_pp_group().is_last_rank:
1166
            # For mid-pipeline stages, return the hidden states.
1167
            return hidden_states
1168

1169
1170
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)
1171

1172
1173
1174
1175
        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

1176
        # Sample the next token and get logprobs if needed.
1177
        sampling_metadata = self.input_batch.sampling_metadata
1178
        if spec_decode_metadata is None:
1179
            sampler_output = self.sampler(
1180
1181
1182
1183
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
1184
1185
1186
1187
            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
1188
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
1189
            sampler_output = self.sampler(
1190
                logits=bonus_logits,
1191
1192
1193
                sampling_metadata=sampling_metadata,
            )
            bonus_token_ids = sampler_output.sampled_token_ids
1194

1195
1196
1197
            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
1198
            target_logits = logits[spec_decode_metadata.target_logits_indices]
1199
            output_token_ids = self.rejection_sampler(
1200
                spec_decode_metadata,
1201
                None,  # draft_probs
1202
                target_logits,
1203
                bonus_token_ids,
1204
1205
                sampling_metadata,
            )
1206
            sampler_output.sampled_token_ids = output_token_ids
1207
1208
1209

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
1210
1211
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
1212
1213
1214
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
1215
            if seq_len < req_state.num_tokens:
1216
                # Ignore the sampled token for partial prefills.
1217
                # Rewind the generator state as if the token was not sampled.
1218
                # This relies on cuda-specific torch-internal impl details
1219
1220
1221
1222
1223
1224
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)
1225

1226
1227
        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
1228
1229
1230
1231
1232
1233
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
1234
            hidden_states[:num_scheduled_tokens],
1235
1236
1237
            scheduler_output,
        )

1238
        # Get the valid generated tokens.
1239
1240
1241
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
1242
            # No spec decode tokens.
1243
1244
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
1245
            # Includes spec decode tokens.
1246
            valid_sampled_token_ids = self.rejection_sampler.parse_output(
1247
1248
1249
                sampled_token_ids,
                self.input_batch.vocab_size,
            )
1250
1251
1252
        # Mask out the sampled tokens that should not be sampled.
        for i in discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()
1253

1254
        if not self.use_spec_decode:
1255
            # Speculative decoding is not enabled.
1256
            spec_token_ids = None
1257
1258
        elif self.speculative_config.method == "ngram":
            assert isinstance(self.drafter, NgramProposer)
1259
            spec_token_ids = self.generate_draft_token_ids(
1260
                valid_sampled_token_ids, sampling_metadata)
1261
        elif self.speculative_config.use_eagle():
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(valid_sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]
                else:
                    # Partial prefill (rare case).
                    # Get the next token id from the request state.
                    req_id = self.input_batch.req_ids[i]
                    req_state = self.requests[req_id]
                    seq_len = (req_state.num_computed_tokens +
                               scheduler_output.num_scheduled_tokens[req_id])
                    next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)
1281
            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
1282
1283
1284
1285

            if spec_decode_metadata is None:
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids[:num_scheduled_tokens]
1286
                target_positions = positions[:num_scheduled_tokens]
1287
                if self.use_aux_hidden_state_outputs:
1288
1289
1290
                    target_hidden_states = torch.cat(
                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
                        dim=-1)
1291
1292
                else:
                    target_hidden_states = hidden_states[:num_scheduled_tokens]
1293
1294
                target_slot_mapping = eagle_attn_metadata.slot_mapping
                cu_num_tokens = eagle_attn_metadata.query_start_loc
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
            else:
                # TODO(woosuk): Refactor this.
                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                num_rejected_tokens = [
                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                    for i, n in enumerate(num_draft_tokens)
                ]
                num_rejected_tokens = torch.tensor(
                    num_rejected_tokens,
                    dtype=torch.int32,
                    device=self.device,
                )
                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
1308
                    eagle_attn_metadata.query_start_loc,
1309
1310
1311
1312
                    num_rejected_tokens,
                )
                target_token_ids = self.input_ids[token_indices]
                target_positions = positions[token_indices]
1313
                if self.use_aux_hidden_state_outputs:
1314
1315
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states], dim=-1)
1316
1317
                else:
                    target_hidden_states = hidden_states[token_indices]
1318
1319
                target_slot_mapping = eagle_attn_metadata.slot_mapping[
                    token_indices]
1320

1321
            draft_token_ids = self.drafter.propose(
1322
1323
1324
1325
1326
1327
                target_token_ids=target_token_ids,
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                target_slot_mapping=target_slot_mapping,
                next_token_ids=next_token_ids,
                cu_num_tokens=cu_num_tokens,
1328
                block_table=eagle_attn_metadata.block_table,
1329
1330
1331
                sampling_metadata=sampling_metadata,
            )
            spec_token_ids = draft_token_ids.tolist()
1332

1333
1334
1335
1336
        # Clear KVConnector state after all KVs are generated.
        if has_kv_transfer_group():
            get_kv_transfer_group().clear_connector_metadata()

1337
        return ModelRunnerOutput(
1338
            req_ids=self.input_batch.req_ids,
1339
            req_id_to_index=self.input_batch.req_id_to_index,
1340
            sampled_token_ids=valid_sampled_token_ids,
1341
            spec_token_ids=spec_token_ids,
1342
1343
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
1344
1345
        )

1346
1347
    def generate_draft_token_ids(
        self,
1348
        sampled_token_ids: list[list[int]],
1349
        sampling_metadata: SamplingMetadata,
1350
    ) -> list[list[int]]:
1351
        # TODO(woosuk): Optimize.
1352
        draft_token_ids: list[list[int]] = []
1353
1354
1355
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
1356
1357
1358
1359
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

1360
1361
            # Skip requests that require sampling parameters that are not
            # supported with speculative decoding.
1362
1363
1364
1365
1366
            req_id = self.input_batch.req_ids[i]
            if not is_spec_decode_supported(req_id, self.input_batch):
                draft_token_ids.append([])
                continue

1367
1368
            # Add sampled_token_ids to token_ids_cpu.
            start_idx = self.input_batch.num_tokens_no_spec[i]
1369
            end_idx = start_idx + num_sampled_ids
1370
1371
1372
1373
1374
            if end_idx >= self.max_model_len:
                # Skip requests that have already reached the max model length.
                draft_token_ids.append([])
                continue

1375
            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
1376
            drafter_output = self.drafter.propose(
1377
                self.input_batch.token_ids_cpu[i, :end_idx])
1378
1379
1380
1381
1382
1383
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

1384
1385
1386
    def load_model(self) -> None:
        """Load the model, plus LoRA and the drafter when configured.

        The device memory consumed while loading is stored in
        ``self.model_memory_usage``. Both that amount and the wall-clock load
        time are logged.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as profiler:  # noqa: SIM117
            load_start = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                # Replace the freshly loaded model with its LoRA-enabled form.
                self.model = self.load_lora_model(
                    self.model,
                    self.model_config,
                    self.scheduler_config,
                    self.lora_config,
                    self.device,
                )
            if hasattr(self, "drafter"):
                # The drafter is given the (possibly LoRA-wrapped) model.
                logger.info("Loading drafter model...")
                self.drafter.load_model(self.model)
            if self.use_aux_hidden_state_outputs:
                # Configure the auxiliary hidden-state layers from the
                # model's EAGLE-3 layer selection.
                eagle3_layers = self.model.get_eagle3_aux_hidden_state_layers()
                self.model.set_aux_hidden_state_layers(eagle3_layers)
            load_end = time.perf_counter()
        self.model_memory_usage = profiler.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    load_end - load_start)

    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for the requests that asked for them.

        When a prompt is prefilled in chunks, its logprobs are computed over
        several steps. Partial results accumulate in CPU tensors held in
        ``self.input_batch.in_progress_prompt_logprobs_cpu``. They are only
        returned once the final chunk has been processed.

        Args:
            hidden_states: Hidden states for this step's scheduled tokens.
                Each request's rows start at ``self.query_start_loc_np[i]``.
            scheduler_output: Scheduler output for this step. It supplies the
                number of tokens scheduled for each request.

        Returns:
            A dict from request id to its complete prompt logprobs, for the
            requests whose prompt logprobs were completed in this step. It is
            empty when no request in the batch wants prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Set up target LogprobsTensors object.
            logprobs_tensors = in_progress_dict.get(req_id)
            if logprobs_tensors is None:
                # Create empty logprobs CPU tensors for the entire prompt.
                # If chunked, we'll copy in slice by slice.
                logprobs_tensors = LogprobsTensors.empty_cpu(
                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
                in_progress_dict[req_id] = logprobs_tensors

            # Determine number of logits to retrieve.
            start_idx = request.num_computed_tokens
            start_tok = start_idx + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens <= num_remaining_tokens:
                # This is a chunk, more tokens remain.
                # In the == case, there are no more prompt logprobs to produce
                # but we want to defer returning them to the next step where we
                # have new generated tokens to return.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)
                prompt_logprobs_dict[req_id] = logprobs_tensors

            if num_logits <= 0:
                # This can happen for the final chunk if we prefilled exactly
                # (num_prompt_tokens - 1) tokens for this request in the prior
                # step. There are no more prompt logprobs to produce.
                continue

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this slice is moved to the
            # device, rather than the whole prompt, and only once we know
            # there are logits to gather.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok + num_logits],
                dtype=torch.int64,
            ).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            chunk_slice = slice(start_idx, start_idx + num_logits)
            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
                token_ids, non_blocking=True)
            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
                ranks, non_blocking=True)

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]
            del in_progress_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers before the
        # CPU tensors are handed back to the caller.
        if prompt_logprobs_dict:
            torch.cuda.synchronize()

        return prompt_logprobs_dict

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        skip_attn: bool = True,
    ) -> torch.Tensor:
        """Run the model on a synthetic batch of ``num_tokens`` tokens.

        The tokens are split as evenly as possible over
        ``min(num_tokens, max_num_seqs)`` dummy requests. Any remainder goes
        to the last request.

        Args:
            num_tokens: Total number of tokens in the dummy batch. It must not
                exceed ``max_num_batched_tokens``.
            skip_attn: If True, the forward pass runs with no attention
                metadata. Otherwise attention metadata is built for the dummy
                requests.

        Returns:
            The rows of the model output at each dummy request's last token.
        """
        # Set num_scheduled_tokens based on num_tokens and max_num_seqs
        # for dummy run with LoRA so that the num_reqs collectively
        # has num_tokens in total.
        assert num_tokens <= self.scheduler_config.max_num_batched_tokens
        max_num_reqs = self.scheduler_config.max_num_seqs
        num_reqs = min(num_tokens, max_num_reqs)
        min_tokens_per_req = num_tokens // num_reqs
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)

        if skip_attn:
            attn_metadata = None
        else:
            query_start_loc = self.query_start_loc[:num_reqs + 1]
            seq_lens = self.seq_lens[:num_reqs]

            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc, seq_lens=seq_lens)

            # The dummy batch has `num_reqs` requests, matching the slices
            # above. Passing `num_tokens` here would make the builder disagree
            # with `common_attn_metadata` whenever num_tokens > max_num_seqs.
            attn_metadata = self.attn_metadata_builder.build(
                num_reqs=num_reqs,
                num_actual_tokens=num_tokens,
                max_query_len=num_tokens,
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
            model = self.model
            if self.is_multimodal_model:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]
            else:
                input_ids = self.input_ids[:num_tokens]
                inputs_embeds = None
            if self.uses_mrope:
                positions = self.mrope_positions[:, :num_tokens]
            else:
                positions = self.positions[:num_tokens]

            if get_pp_group().is_first_rank:
                intermediate_tensors = None
            else:
                # Lazily allocate max-size intermediate tensors once, then
                # slice them to the dummy batch size.
                if self.intermediate_tensors is None:
                    self.intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=self.max_num_tokens,
                            dtype=self.model_config.dtype,
                            device=self.device))
                intermediate_tensors = IntermediateTensors({
                    k: v[:num_tokens]
                    for k, v in self.intermediate_tensors.items()
                })

            with set_forward_context(attn_metadata,
                                     self.vllm_config,
                                     num_tokens=num_tokens):
                outputs = model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
                    inputs_embeds=inputs_embeds,
                )
            if self.use_aux_hidden_state_outputs:
                hidden_states, _ = outputs
            else:
                hidden_states = outputs

            # Warm up the EAGLE drafter under the same predicate that
            # execute_model uses to select it.
            if (self.use_spec_decode
                    and self.speculative_config.use_eagle()):
                assert isinstance(self.drafter, EagleProposer)
                self.drafter.dummy_run(num_tokens)

        # Index of the last token of each dummy request.
        logit_indices = np.cumsum(num_scheduled_tokens) - 1
        return hidden_states[logit_indices]

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Exercise the sampler (and the rejection sampler) on dummy inputs.

        Logits are computed from ``hidden_states``, one row per dummy request.
        They are sampled with fixed, non-greedy dummy sampling parameters.
        When speculative decoding is enabled, the rejection sampler is also
        run with one dummy draft token per request.

        Returns:
            Whatever ``self.sampler`` returned. NOTE(review): the annotation
            says ``torch.Tensor`` but this is the sampler's output object;
            confirm.

        Raises:
            RuntimeError: With a hint to lower ``max_num_seqs`` or
                ``gpu_memory_utilization`` if the sampler step runs out of
                memory. Other runtime errors are re-raised unchanged.
        """
        logits = self.model.compute_logits(hidden_states, None)
        num_reqs = logits.size(0)

        def filled(value):
            # One-dimensional per-request tensor holding `value` everywhere.
            return torch.full((num_reqs, ), value, device=self.device)

        dummy_metadata = SamplingMetadata(
            temperature=filled(0.5),
            all_greedy=False,
            all_random=False,
            top_p=filled(0.9),
            top_k=filled(logits.size(1) - 1),
            min_p=None,
            generators={},
            max_num_logprobs=None,
            no_penalties=True,
            prompt_token_ids=None,
            frequency_penalties=filled(0.1),
            presence_penalties=filled(0.1),
            repetition_penalties=filled(0.1),
            output_token_ids=[[] for _ in range(num_reqs)],
            min_tokens={},
            logit_bias=[None for _ in range(num_reqs)],
            allowed_token_ids_mask=None,
            bad_words_token_ids={},
        )
        try:
            sampler_output = self.sampler(logits=logits,
                                          sampling_metadata=dummy_metadata)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                raise
            raise RuntimeError(
                "CUDA out of memory occurred when warming up sampler with "
                f"{num_reqs} dummy requests. Please try lowering "
                "`max_num_seqs` or `gpu_memory_utilization` when "
                "initializing the engine.") from e

        if self.use_spec_decode:
            # One dummy draft token per request.
            dummy_drafts = [[0] for _ in range(num_reqs)]
            dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy(
                dummy_drafts, self.device)
            total_drafts = sum(map(len, dummy_drafts))
            target_logits = torch.randn(total_drafts,
                                        logits.shape[-1],
                                        device=self.device,
                                        dtype=logits.dtype)
            # NOTE(woosuk): Here, we should use int32 because the sampler uses
            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
            # will occur at runtime.
            bonus_token_ids = torch.zeros(num_reqs,
                                          device=self.device,
                                          dtype=torch.int32)
            self.rejection_sampler(
                dummy_spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                dummy_metadata,
            )
        return sampler_output

    def profile_run(self) -> None:
        """Run dummy forward passes sized for profiling, then clean up.

        For multimodal models with a nonzero encoder budget and encoder
        cache, the multimodal encoder first runs on a batch of dummy items of
        a single modality. Its outputs are stored under
        ``self.encoder_cache["tmp"]``. The model then runs on
        ``self.max_num_tokens`` dummy tokens, and the sampler runs on the last
        pipeline-parallel rank. All temporaries are released afterwards.
        """
        # TODO: handle encoder-decoder models once we support them.
        profile_mm_encoder = (self.is_multimodal_model
                              and self.max_num_encoder_input_tokens > 0
                              and self.encoder_cache_size > 0)
        if profile_mm_encoder:
            registry = self.mm_registry

            # NOTE: Currently model is profiled with a single non-text
            # modality with the max possible input tokens even when
            # it supports multiple.
            tokens_by_modality = (
                registry.get_max_tokens_per_item_by_nonzero_modality(
                    self.model_config))
            modality, max_tokens_per_item = max(tokens_by_modality.items(),
                                                key=lambda kv: kv[1])

            # Number of items of this modality allowed by the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)
            items_by_encoder_budget = cdiv(encoder_budget,
                                           max_tokens_per_item)

            # Number of items of this modality allowed by the decoder budget.
            # NOTE: max_num_batched_tokens is deliberately not considered
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            items_per_req = registry.get_mm_limits_per_prompt(
                self.model_config)[modality]
            items_by_decoder_budget = self.max_num_reqs * items_per_req

            max_num_mm_items = min(items_by_encoder_budget,
                                   items_by_decoder_budget)

            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, modality)

            # Build a batch of identical dummy multimodal inputs.
            single_item_kwargs = registry.get_decoder_dummy_data(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_counts={modality: 1},
            ).multi_modal_data
            batched = MultiModalKwargs.batch([single_item_kwargs] *
                                             max_num_mm_items)
            dummy_mm_inputs = MultiModalKwargs.as_kwargs(batched,
                                                         device=self.device)

            # Run the multimodal encoder and check the number of outputs.
            encoder_outputs = self.model.get_multimodal_embeddings(
                **dummy_mm_inputs)
            sanity_check_mm_encoder_outputs(
                encoder_outputs,
                expected_num_items=max_num_mm_items,
            )

            # Keep the dummy encoder outputs in the encoder cache (cleared
            # again below).
            self.encoder_cache["tmp"] = dict(enumerate(encoder_outputs))

        hidden_states = self._dummy_run(self.max_num_tokens)
        sampler_output = (self._dummy_sampler_run(hidden_states)
                          if get_pp_group().is_last_rank else None)
        torch.cuda.synchronize()
        del hidden_states, sampler_output
        self.encoder_cache.clear()
        gc.collect()

    def capture_model(self) -> None:
        """Capture CUDA graphs for every configured CUDA graph batch size.

        If ``self.use_cuda_graph`` is False, this only logs a warning and
        returns. Otherwise it logs the elapsed time and the drop in free GPU
        memory caused by the capture.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        compilation_config = self.vllm_config.compilation_config
        t_start = time.perf_counter()
        free_before = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            skip_attn = not compilation_config.full_cuda_graph
            num_warmups = compilation_config.cudagraph_num_of_warmups
            for batch_size in reversed(self.cudagraph_batch_sizes):
                # Warm-up runs, then the final run for this shape.
                for _ in range(num_warmups):
                    self._dummy_run(batch_size, skip_attn=skip_attn)
                self._dummy_run(batch_size, skip_attn=skip_attn)

        elapsed_time = time.perf_counter() - t_start
        cuda_graph_size = free_before - torch.cuda.mem_get_info()[0]
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate and bind the KV cache tensors described by a config.

        Only a single KV cache group is supported. For each layer, a zeroed
        tensor is allocated on ``self.device`` with the shape reported by the
        attention backend. The tensors are then passed to ``bind_kv_cache``,
        and the attention metadata builder is created.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: If the config has more than one KV cache
                group.
            ValueError: If a layer's KV cache spec is not an
                ``AttentionSpec``.
        """
        if len(kv_cache_config.kv_cache_groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")
        self.kv_cache_config = kv_cache_config

        kv_caches: dict[str, torch.Tensor] = {}
        for group in kv_cache_config.kv_cache_groups:
            spec = group.kv_cache_spec
            for layer_name in group.layer_names:
                tensor_size = kv_cache_config.tensors[layer_name].size
                num_blocks, leftover = divmod(tensor_size, spec.page_size_bytes)
                assert leftover == 0
                # `num_blocks` is what this runner can use. The config's
                # `num_blocks` is what the KVCacheManager may allocate. It is
                # the minimum over all GPUs, which may have different layer
                # counts and memory capacities, so it must not exceed ours.
                assert num_blocks >= kv_cache_config.num_blocks
                if not isinstance(spec, AttentionSpec):
                    # TODO: add new branches when introducing more types of
                    # KV cache specs.
                    raise ValueError("Unknown KV cache spec type.")
                shape = self.attn_backend.get_kv_cache_shape(
                    num_blocks, spec.block_size, spec.num_kv_heads,
                    spec.head_size)
                kv_caches[layer_name] = torch.zeros(shape,
                                                    dtype=spec.dtype,
                                                    device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self),
            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
            self.input_batch.block_table)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """Build the KV cache spec for every attention layer.

        Iterates over the ``Attention`` layers found in the vLLM config. Each
        decoder attention layer gets a ``SlidingWindowSpec`` if it has a
        sliding window, and a ``FullAttentionSpec`` otherwise.

        Returns:
            A dict mapping layer names to their KV cache spec. Encoder and
            encoder-only attention layers need no KV cache and are omitted.

        Raises:
            NotImplementedError: For encoder-decoder attention layers.
            ValueError: For any other attention type.
        """
        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        specs: dict[str, KVCacheSpec] = {}

        for layer_name, attn_module in layers.items():
            attn_type = attn_module.attn_type
            # TODO: Support other attention modules, e.g., cross-attention
            if attn_type == AttentionType.DECODER:
                # Fields shared by the full and sliding-window specs.
                shared = dict(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=self.kv_cache_dtype,
                    use_mla=use_mla,
                )
                if attn_module.sliding_window is None:
                    specs[layer_name] = FullAttentionSpec(**shared)
                else:
                    specs[layer_name] = SlidingWindowSpec(
                        sliding_window=attn_module.sliding_window, **shared)
            elif attn_type in (AttentionType.ENCODER,
                               AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            else:
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        return specs