gpu_model_runner.py 65.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
import weakref
6
from typing import TYPE_CHECKING, Optional, Union
7
8
9
10
11
12

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

13
from vllm.attention import AttentionType, get_attn_backend
14
from vllm.attention.layer import Attention
15
from vllm.config import CompilationLevel, VllmConfig
16
from vllm.distributed.parallel_state import get_pp_group, graph_capture
17
from vllm.forward_context import set_forward_context
18
from vllm.inputs import INPUT_REGISTRY
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.fused_moe import FusedMoE
21
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
22
from vllm.model_executor.model_loader import get_model
23
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
24
from vllm.multimodal.utils import group_mm_inputs_by_modality
25
from vllm.sampling_params import SamplingType
26
from vllm.sequence import IntermediateTensors
27
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
28
                        LayerBlockType, cdiv, is_pin_memory_available)
29
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
30
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
31
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
32
33
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
34
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
35
from vllm.v1.sample.metadata import SamplingMetadata
36
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
37
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
38
from vllm.v1.utils import bind_kv_cache
39
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
40
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
41
42

# Imported for type annotations only; the runner refers to it through the
# string annotation "SchedulerOutput", so no runtime import is needed.
if TYPE_CHECKING:
    from vllm.v1.core.scheduler_output import SchedulerOutput

logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin):
49
50
51

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache configuration and allocate the runner's persistent buffers.

        Args:
            vllm_config: Engine configuration; its sub-configs are cached as
                attributes on the runner.
            device: Device on which the GPU-side persistent buffers are
                allocated.

        Raises:
            NotImplementedError: If `get_attn_backend` returns no backend, or
                if speculative decoding is configured without
                `ngram_prompt_lookup_min` (only n-gram drafting is supported).
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        # Number of KV blocks needed to hold a request of max_model_len tokens.
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            # Name the function actually called so the log is greppable.
            error_msg = (
                f"Error with get_attn_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        # The builder receives a weak proxy so it does not hold a strong
        # reference back to this runner.
        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        if self.is_multimodal_model:
            # NOTE: Initialized client is only used for processing dummy
            # multimodal data into multimodal kwargs for GPU memory profiling.
            # Only applicable to multimodal models with legacy input mapper.
            self.mm_input_mapper_profiling = MMInputCacheClient(
                self.model_config)
            self.mm_input_mapper_profiling.use_cache = False

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True
            self.rejection_sampler = RejectionSampler()
            # TODO: find a better way to check if we are using ngram.
            # Raise explicitly instead of using `assert`: asserts are stripped
            # under `python -O`, which would let an unsupported config through.
            if not self.speculative_config.ngram_prompt_lookup_min:
                raise NotImplementedError(
                    "Currently, only ngram spec decode is supported in V1.")
            if get_pp_group().is_last_rank:
                self.drafter = NgramProposer()
                # Trigger Numba JIT compilation for N-gram proposer.
                # This usually takes less than 1 second.
                self.drafter.propose(
                    np.zeros(1024, dtype=np.int32),
                    self.speculative_config.ngram_prompt_lookup_min,
                    self.speculative_config.num_speculative_tokens,
                )

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}

        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Sized so any slice by num_reqs + 1, a position, or a token count fits.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each `*_np` attribute is a NumPy view sharing memory with the CPU
        # tensor next to it, so writes through the view fill the tensor.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Sync cached request states and the persistent batch with the
        scheduler's latest decision.

        The states maintained here feed `_prepare_inputs`, which builds the
        GPU input tensors for the model. Sampling metadata is refreshed only
        when requests entered or left the persistent batch.
        """
        batch = self.input_batch
        freed_slots: list[int] = []

        # Drop finished requests everywhere: cached state, encoder outputs,
        # and their slot in the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)
            slot = batch.remove_request(req_id)
            if slot is not None:
                freed_slots.append(slot)

        # Release encoder outputs the scheduler no longer needs; forget a
        # request's entry once it holds no outputs at all.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            cached_outputs = self.encoder_cache.get(req_id)
            if cached_outputs is None:
                continue
            cached_outputs.pop(input_id, None)
            if not cached_outputs:
                self.encoder_cache.pop(req_id, None)

        # Evict requests that are in the persistent batch but not scheduled
        # this step (preempted, or simply skipped). Their cached states are
        # kept because they will be scheduled again later.
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        scheduled = scheduler_output.num_scheduled_tokens.keys()
        for req_id in batch.req_id_to_index.keys() - scheduled:
            slot = batch.remove_request(req_id)
            assert slot is not None
            freed_slots.append(slot)

        pending: list[str] = []

        # Register brand-new requests.
        for new_req in scheduler_output.scheduled_new_reqs:
            req_id = new_req.req_id
            params = new_req.sampling_params
            generator = None
            if params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(params.seed)

            state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req.prompt_token_ids,
                prompt=new_req.prompt,
                mm_inputs=new_req.mm_inputs,
                mm_positions=new_req.mm_positions,
                sampling_params=params,
                generator=generator,
                block_ids=new_req.block_ids,
                num_computed_tokens=new_req.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req.lora_request,
            )
            self.requests[req_id] = state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in state.mm_inputs:
                    image_thw = mm_input.get("image_grid_thw")
                    if image_thw is not None:
                        image_grid_thw.extend(image_thw.tolist())
                    video_thw = mm_input.get("video_grid_thw")
                    if video_thw is not None:
                        video_grid_thw.extend(video_thw.tolist())
                    grid_ts = mm_input.get("second_per_grid_ts")
                    if grid_ts is not None:
                        second_per_grid_ts.extend(grid_ts)

                positions, delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        state.prompt_token_ids,
                        hf_config=self.model_config.hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                    )
                state.mrope_positions = positions
                state.mrope_position_delta = delta

            pending.append(req_id)

        # Refresh running and resumed requests.
        for cached in scheduler_output.scheduled_cached_reqs:
            req_id = cached.req_id
            state = self.requests[req_id]
            computed = cached.num_computed_tokens
            state.num_computed_tokens = computed

            # Record the token(s) sampled in the previous step, if any.
            # Unverified spec decode tokens are not part of new_token_ids.
            num_new = computed + len(cached.new_token_ids) - state.num_tokens
            if num_new == 1:
                # Common case: skip the list slice.
                state.output_token_ids.append(cached.new_token_ids[-1])
            elif num_new > 0:
                state.output_token_ids.extend(cached.new_token_ids[-num_new:])

            if cached.resumed_from_preemption:
                # A resumed request gets a fresh set of block IDs.
                state.block_ids = cached.new_block_ids
            else:
                # A running request only receives newly allocated blocks.
                state.block_ids.extend(cached.new_block_ids)

            slot = batch.req_id_to_index.get(req_id)
            if slot is None:
                # Not in the persistent batch: either preempted and resumed,
                # or skipped last step. Re-insert it below.
                pending.append(req_id)
                continue

            batch.num_computed_tokens_cpu[slot] = computed
            batch.block_table.append_row(cached.new_block_ids, slot)

            # Write the new tokens, then any speculative tokens, into the
            # request's token buffer.
            end = computed + len(cached.new_token_ids)
            batch.token_ids_cpu[slot, computed:end] = cached.new_token_ids
            batch.num_tokens_no_spec[slot] = end

            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                spec_end = end + len(spec_token_ids)
                batch.token_ids_cpu[slot, end:spec_end] = spec_token_ids
                end = spec_end
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            batch.num_tokens[slot] = end

        # Sampling metadata only needs re-copying when membership changed.
        batch_changed = bool(freed_slots) or bool(pending)

        # Re-insert pending requests, reusing the lowest freed slots first
        # and appending once none remain.
        freed_slots = sorted(freed_slots, reverse=True)
        for req_id in pending:
            state = self.requests[req_id]
            slot = freed_slots.pop() if freed_slots else None
            batch.add_request(state, slot)

        # Compact the batch over any freed slots left unused.
        if freed_slots:
            batch.condense(freed_slots)

        if batch_changed:
            batch.refresh_sampling_metadata()

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> tuple[FlashAttentionMetadata, torch.Tensor]:
        """Build this step's model inputs and attention metadata.

        Fills the persistent CPU staging buffers (input ids, positions, slot
        mapping, query start locations, sequence lengths). It then launches
        non-blocking copies of the input ids and the (M-RoPE or 1D) positions
        to the GPU, and asks the attention metadata builder for metadata.

        Args:
            scheduler_output: The scheduler's decision for this step; it must
                schedule at least one token.

        Returns:
            ``(attn_metadata, logits_indices)``. ``logits_indices`` indexes
            the flattened token dimension: the last scheduled token of each
            request, or the positions produced by
            ``_calc_spec_decode_metadata`` when spec decode tokens are
            scheduled.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        self.attn_metadata_builder.reorder_batch(self.input_batch,
                                                 scheduler_output)

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request, in batch order.
        # One list comprehension plus a single np.array conversion replaces
        # per-element writes into a NumPy array from a Python loop.
        tokens = [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in self.input_batch.req_ids
        ]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
        # `req_indices` inherits int32 from `arange_np`. Widen it before it is
        # multiplied by a per-request stride below. Otherwise the product is
        # computed in int32 and overflows once
        # max_num_reqs * max_model_len exceeds 2**31 - 1.
        req_indices_i64 = req_indices.astype(np.int64)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_ids_stride = self.input_batch.token_ids_cpu.shape[1]
        token_indices = positions_np + req_indices_i64 * token_ids_stride

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices_i64 * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # Index the torch tensor directly rather than using np.take, for the
        # same large-tensor performance reason noted for the token ids.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)

        # Prepare for cascade attention if needed.
        common_prefix_len = self._compute_cascade_attn_prefix_len(
            num_scheduled_tokens,
            scheduler_output.num_common_prefix_blocks,
        )
        attn_metadata = self.attn_metadata_builder.build(
            num_reqs=num_reqs,
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            common_prefix_len=common_prefix_len,
        )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if use_spec_decode:
            logits_indices = self._calc_spec_decode_metadata(
                scheduler_output, cu_num_tokens)
        else:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = attn_metadata.query_start_loc[1:] - 1

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices

    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
663
        use_cascade = self.attn_backend.use_cascade_attention(
664
665
666
667
668
669
670
671
672
673
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=False,  # FIXME
            use_sliding_window=self.sliding_window is not None,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

674
675
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
676
        for index, req_id in enumerate(self.input_batch.req_ids):
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

726
727
728
729
    def _calc_spec_decode_metadata(
        self,
        scheduler_output: "SchedulerOutput",
        cu_num_tokens: np.ndarray,
730
    ) -> torch.Tensor:
731
732
733
        # Get the number of spec decode tokens for each request.
        num_reqs = self.input_batch.num_reqs
        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
734
        for i, req_id in enumerate(self.input_batch.req_ids):
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
            num_spec_decode_tokens[i] = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))

        # Get spec decode logits indices.
        # E.g.,   num_scheduled_tokens: [4, 100, 3,   100, 2]
        #         cu_num_tokens:        [4, 104, 107, 207, 209]
        #         num_spec_tokens_list: [3, 0,   2,   0,   1]
        #         num_sampled_tokens:   [4, 1,   3,   1,   2]
        #         spec_decode_logits_indices:
        #                 [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        num_sampled_tokens = num_spec_decode_tokens + 1
        # logits_start_loc: [0, 103, 104, 206, 207]
        logits_start_loc = cu_num_tokens - num_sampled_tokens
        # [0, 103, 104, 206, 207] ->
        #               [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
        # The following three lines:
        # [4, 1,   3,   1,   2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        cumsums_sampled_offsets = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3.  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        #       -  [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        #      -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        total_num_sampled_tokens = num_sampled_tokens.sum()
        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                          cumsums_sampled_offsets)

        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        spec_decode_logits_indices = logits_start_loc + sampled_arange
        return torch.from_numpy(spec_decode_logits_indices).to(
            self.device, non_blocking=True)

772
773
774
775
776
777
    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
778
779
        mm_inputs: list[MultiModalKwargs] = []
        req_input_ids: list[tuple[str, int]] = []
780
781
782
783
784
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            for output in curr_group_outputs:
                encoder_outputs.append(output)
813
814
815
816
817
818
819
820
821
822

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
823
824
    ) -> list[torch.Tensor]:
        encoder_outputs: list[torch.Tensor] = []
825
        for req_id in self.input_batch.req_ids:
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

858
859
860
    def get_model(self) -> nn.Module:
        """Return the model instance held by this runner."""
        model = self.model
        return model

861
862
863
864
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one step: update request state, run the forward pass, sample
        tokens (with rejection sampling when spec decode is enabled), and
        package the results.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Activations from the previous pipeline
                stage. Required (asserted) on non-first pipeline ranks and
                ignored on the first rank.

        Returns:
            On non-last pipeline ranks, the model's hidden states. On the
            last rank, a ``ModelRunnerOutput`` whose per-request fields
            follow ``self.input_batch.req_ids``.
        """
        self._update_states(scheduler_output)

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []

        # Prepare the decoder inputs.
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            # Copy the received activations into the persistent buffers so
            # that the (possibly CUDA-graphed) model reads fixed addresses.
            for k, v in intermediate_tensors.items():
                self.intermediate_tensors[k][:num_input_tokens].copy_(
                    v[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                k: v[:num_input_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            return hidden_states

        # Drop the CUDA-graph padding before selecting tokens to sample.
        hidden_states = hidden_states[:num_scheduled_tokens]
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if not self.use_spec_decode:
            sampler_output = self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        else:
            target_probs = self.model.sampler.compute_probs(
                logits, sampling_metadata)
            # The draft tokens must be listed in input-batch order. The
            # logits/target_probs rows were laid out by iterating
            # self.input_batch.req_ids, and sampling_metadata comes from the
            # input batch. The iteration order of
            # scheduler_output.num_scheduled_tokens is not guaranteed to
            # match, so it must not be used here.
            draft_token_ids = [
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
                for req_id in self.input_batch.req_ids
            ]
            sampler_output = self.rejection_sampler(draft_token_ids,
                                                    target_probs,
                                                    sampling_metadata)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token (the request is still mid-prompt).
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Get the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens; rejected slots hold
            # INVALID_TOKEN_ID and are filtered out per row.
            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            # TODO(woosuk): Optimize this.
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in sampled_token_ids[valid_mask].split(gen_lens)
            ]

        if not self.use_spec_decode:
            spec_token_ids = None
        else:
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids)

        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )
        return model_runner_output

1024
1025
    def generate_draft_token_ids(
        self,
        sampled_token_ids: list[list[int]],
    ) -> list[list[int]]:
        """Propose draft tokens for each request with the n-gram drafter.

        For each request, the tokens sampled in this step are written into
        its row of ``input_batch.token_ids_cpu`` starting at
        ``num_tokens_no_spec[i]``. The drafter is then run on that row's
        tokens so far.

        Args:
            sampled_token_ids: Valid sampled tokens for this step, one list
                per request, indexed like the input batch.

        Returns:
            One list of draft token ids per request. A list is empty when the
            request sampled nothing, when its token buffer has no room left
            past the new tokens, or when the drafter proposed nothing.
        """
        # TODO(woosuk): Optimize.
        draft_token_ids: list[list[int]] = []
        token_ids_cpu = self.input_batch.token_ids_cpu
        # Width of each request's token buffer.
        max_len = token_ids_cpu.shape[1]
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            # Add sampled_token_ids to token_ids_cpu.
            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + num_sampled_ids
            if end_idx >= max_len:
                # The request has reached the end of its token buffer, so
                # there is no room for draft tokens. Past the end, the slice
                # assignment below would fail with a shape mismatch.
                draft_token_ids.append([])
                continue
            token_ids_cpu[i, start_idx:end_idx] = sampled_ids

            drafter_output = self.drafter.propose(
                token_ids_cpu[i, :end_idx],
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

1052
1053
1054
    def load_model(self) -> None:
        """Build the model from ``self.vllm_config`` and wrap it for LoRA when
        a LoRA config is set. Record the device memory consumed in
        ``self.model_memory_usage`` and log the memory and time taken."""
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            load_start = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)
            load_end = time.perf_counter()
        self.model_memory_usage = m.consumed_memory
        memory_gib = self.model_memory_usage / float(2**30)
        elapsed = load_end - load_start
        logger.info("Model loading took %.4f GB and %.6f seconds",
                    memory_gib, elapsed)
1068

1069
1070
1071
1072
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> dict[str, Optional[LogprobsTensors]]:
        """Compute prompt logprobs for this step's chunk of each request that
        asked for them.

        Requests whose final prompt chunk is processed in this step are
        removed from ``self.input_batch.num_prompt_logprobs``.

        Args:
            hidden_states: Decoder hidden states for this step's scheduled
                tokens, laid out per ``self.query_start_loc_np``.
            scheduler_output: The scheduler's decisions for this step.

        Returns:
            A mapping from request id to CPU ``LogprobsTensors`` (empty if no
            request needs prompt logprobs).
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Determine number of logits to retrieve.
            start_tok = request.num_computed_tokens + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens < num_remaining_tokens:
                # This is a chunk, more tokens remain.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this chunk's slice is turned
            # into a tensor and copied to the device. Converting the whole
            # prompt on every chunk would cost O(prompt_len) per chunk. The
            # explicit int64 dtype keeps an empty slice an integer tensor.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok + num_logits],
                dtype=torch.int64,
            ).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            prompt_logprobs_dict[req_id] = LogprobsTensors(
                token_ids.to("cpu", non_blocking=True),
                logprobs.to("cpu", non_blocking=True),
                ranks.to("cpu", non_blocking=True),
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        torch.cuda.synchronize()

        return prompt_logprobs_dict

1139
1140
1141
1142
1143
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
    ) -> torch.Tensor:
        """Run one forward pass over the first ``num_tokens`` entries of the
        persistent input buffers, with no attention metadata.

        On non-first pipeline ranks, the persistent intermediate-tensor
        buffers are created lazily (sized for ``self.max_num_tokens``) and
        sliced to ``num_tokens``.

        Returns:
            The model's output hidden states.
        """
        model = self.model
        multimodal = self.is_multimodal_model
        # Multimodal models consume embeddings; text-only models consume ids.
        input_ids = None if multimodal else self.input_ids[:num_tokens]
        inputs_embeds = self.inputs_embeds[:num_tokens] if multimodal else None
        positions = (self.mrope_positions[:, :num_tokens]
                     if self.uses_mrope else self.positions[:num_tokens])

        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                name: buf[:num_tokens]
                for name, buf in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config,
                                 num_tokens=num_tokens):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        return hidden_states

    def profile_run(self) -> None:
1181
        # Profile with multimodal encoder & encoder cache.
1182
1183
1184
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):
1185

1186
            # NOTE: Currently model is profiled with a single non-text
1187
1188
            # modality with the max possible input tokens even when
            # it supports multiple.
1189
1190
1191
            max_tokens_by_modality_dict = (
                MULTIMODAL_REGISTRY.
                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
1192
1193
1194
1195
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
1196
1197
1198
1199
1200
1201
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)
1202
1203
1204

            # Check how many items of this modality can be supported by
            # the decoder budget.
1205
1206
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data

1230
1231
1232
1233
            # Dummy data definition in V0 may contain multiple multimodal items
            # (e.g, multiple images) for a single request, therefore here we
            # always replicate first item by max_num_mm_items times since in V1
            # they are scheduled to be processed separately.
1234
1235

            # Case when models have a merged processor, their dummy data is
1236
1237
            # already batched `MultiModalKwargs`, therefore we take the first
            # `MultiModalKwargsItem` from the desired modality to profile on.
1238
            if isinstance(dummy_mm_data, MultiModalKwargs):
1239
1240
1241
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
1242
1243
1244
1245

            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through input
            # mapper.
1246
1247
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
1248
            else:
1249
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
1250
                    mm_data=dummy_mm_data,
1251
                    mm_hashes=None,
1252
1253
1254
1255
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)
                dummy_mm_kwargs = mm_kwargs_list[0]

1256
            batched_dummy_mm_inputs = MultiModalKwargs.batch(
1257
                [dummy_mm_kwargs] * max_num_mm_items)
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

1274
1275
1276
1277
        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
1278
        min_tokens_per_req = num_tokens // num_reqs
1279

1280
        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
1281
1282
1283
1284
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

1285
1286
        num_scheduled_tokens = np.array(num_scheduled_tokens_list,
                                        dtype=np.int32)
1287
1288
1289
1290
1291
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        with self.maybe_profile_with_lora(self.lora_config,
                                          num_scheduled_tokens):
            # Trigger compilation for general shape.
1292
            hidden_states = self._dummy_run(self.max_num_tokens)
1293
1294
            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
                logits = self.model.compute_logits(hidden_states, None)
                dummy_tensors = lambda v: torch.full(
                    (num_reqs, ), v, device=self.device)
                dummy_metadata = SamplingMetadata(
                    temperature=dummy_tensors(0.5),
                    all_greedy=False,
                    all_random=False,
                    top_p=dummy_tensors(0.9),
                    top_k=dummy_tensors(logits.size(1) - 1),
                    min_p=None,
                    generators={},
                    max_num_logprobs=None,
                    no_penalties=True,
                    prompt_token_ids=torch.ones_like(logits,
                                                     dtype=torch.int64),
                    frequency_penalties=dummy_tensors(0.1),
                    presence_penalties=dummy_tensors(0.1),
                    repetition_penalties=dummy_tensors(0.1),
                    output_token_ids=[[] for _ in range(num_reqs)],
                    min_tokens={},
                    logit_bias=[None for _ in range(num_reqs)],
                    allowed_token_ids_mask=None,
                )
                sampler_output = self.model.sample(
                    logits=logits, sampling_metadata=dummy_metadata)
1320
            else:
1321
                logits = None
1322
                sampler_output = None
1323
                dummy_metadata = None
1324
            torch.cuda.synchronize()
1325
            del hidden_states, logits, sampler_output, dummy_metadata
1326
            self.encoder_cache.clear()
1327
        gc.collect()
1328
1329

    def capture_model(self) -> None:
        """Capture CUDA graphs for the sizes in `self.cudagraph_batch_sizes`.

        When `self.use_cuda_graph` is false, only a warning is logged and the
        method returns. Otherwise, inside the `graph_capture` context, each
        size in reverse order of `self.cudagraph_batch_sizes` (intended to be
        largest first) is run through `self._dummy_run` once per configured
        warmup iteration and then one more time. Afterwards, the elapsed
        wall-clock time and the drop in free GPU memory are logged.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        # Loop-invariant: read the warmup count once rather than per size.
        num_warmups = (
            self.vllm_config.compilation_config.cudagraph_num_of_warmups)

        start_time = time.perf_counter()
        # torch.cuda.mem_get_info() returns (free, total) bytes; take free.
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                # Warmup runs first; presumably the final run is the one
                # recorded into a graph by the compiled model — TODO confirm.
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        # Approximate: any other allocation made during capture is also
        # reflected in this difference of free memory.
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
1356

1357
1358
1359
1360
    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Allocates one zero-filled tensor per layer listed in
        `kv_cache_config.kv_cache_spec`. The shape comes from
        `self.attn_backend.get_kv_cache_shape` and the dtype from the layer
        spec. The resulting `{layer_name: tensor}` dict is passed to
        `bind_kv_cache` together with the static forward context and
        `self.kv_caches`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: if the config has more than one KV cache
                group, or a layer's spec is not a `FullAttentionSpec`.
            AssertionError: if a layer's tensor size is not a whole multiple
                of its spec's page size.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            page_size = layer_spec.page_size_bytes
            # The byte budget for this layer must hold a whole number of
            # pages; each page corresponds to one block.
            assert tensor_config.size % page_size == 0, (
                f"KV cache tensor size {tensor_config.size} for layer "
                f"{layer_name!r} is not a multiple of page size {page_size}")
            num_blocks = tensor_config.size // page_size
            if not isinstance(layer_spec, FullAttentionSpec):
                raise NotImplementedError(
                    "Unsupported KV cache spec type "
                    f"{type(layer_spec).__name__} for layer {layer_name!r}")
            kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                layer_spec.head_size)
            kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                dtype=layer_spec.dtype,
                                                device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
        Attention module in the static forward context.

        `FusedMoE` entries and ENCODER / ENCODER_ONLY attention layers are
        skipped. DECODER attention layers get a `FullAttentionSpec` built
        from the cache config's block size and the module's KV head count,
        head size and dtype.

        Returns:
            KVCacheSpec: A dictionary mapping layer names to their KV cache
            format. Layers that do not need KV cache are not included.

        Raises:
            AssertionError: if a non-`FusedMoE` entry is not an `Attention`.
            NotImplementedError: for ENCODER_DECODER attention layers.
            ValueError: for an unrecognized attention type.
        """
        forward_ctx = self.vllm_config.compilation_config.static_forward_context
        block_size = self.vllm_config.cache_config.block_size
        kv_cache_spec: KVCacheSpec = {}
        for layer_name, attn_module in forward_ctx.items():
            if isinstance(attn_module, FusedMoE):
                continue

            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(attn_module, Attention), (
                f"Unexpected module type {type(attn_module).__name__} for "
                f"layer {layer_name!r} in the static forward context")
            attn_type = attn_module.attn_type
            if attn_type == AttentionType.DECODER:
                kv_cache_spec[layer_name] = FullAttentionSpec(
                    block_size=block_size,
                    num_kv_heads=attn_module.num_kv_heads,
                    head_size=attn_module.head_size,
                    dtype=attn_module.dtype,
                )
            elif attn_type in (AttentionType.ENCODER,
                               AttentionType.ENCODER_ONLY):
                # encoder-only attention does not need KV cache.
                continue
            elif attn_type == AttentionType.ENCODER_DECODER:
                raise NotImplementedError(
                    "ENCODER_DECODER attention is not supported yet "
                    f"(layer {layer_name!r}).")
            else:
                raise ValueError(f"Unknown attention type: {attn_type}")

        return kv_cache_spec