# SPDX-License-Identifier: Apache-2.0

import gc
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group, graph_capture
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

if TYPE_CHECKING:
    from vllm.v1.core.scheduler_output import SchedulerOutput

# Module-level logger for this runner.
logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, the persistent batch, and reusable buffers.

        Args:
            vllm_config: Top-level engine configuration. Its sub-configs
                (model, cache, LoRA, load, parallel, scheduler, speculative,
                prompt-adapter, observability) are cached as attributes.
            device: Device on which the device-side buffers (input ids,
                positions, embeddings, M-RoPE positions) are allocated.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache reuses the model dtype; otherwise the
        # configured string is mapped to a torch dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        # Blocks needed to hold a max_model_len request (rounded up).
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        if self.is_multimodal_model:
            # NOTE: Initialized client is only used for processing dummy
            # multimodal data into multimodal kwargs for GPU memory profiling.
            # Only applicable to multimodal models with legacy input mapper.
            self.mm_input_mapper_profiling = MMInputCacheClient(
                self.model_config)
            self.mm_input_mapper_profiling.use_cache = False

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Set up speculative decoding.
        self.use_spec_decode = False
        if self.speculative_config:
            self.use_spec_decode = True

            # TODO: find a better way to check if we are using ngram.
            assert self.speculative_config.ngram_prompt_lookup_min, \
                    "Currently, only ngram spec decode is supported in V1."
            # Only the last pipeline-parallel rank gets a drafter.
            if get_pp_group().is_last_rank:
                self.drafter = NgramProposer()
                # Trigger Numba JIT compilation for N-gram proposer.
                # This usually takes less than 1 second.
                self.drafter.propose(
                    np.zeros(1024, dtype=np.int32),
                    self.speculative_config.ngram_prompt_lookup_min,
                    self.speculative_config.num_speculative_tokens,
                )

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # None in the first PP rank. The rest are set after load_model.
        self.intermediate_tensors: Optional[IntermediateTensors] = None

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # Sized so that prefixes of length num_reqs + 1, max_model_len or
        # max_num_tokens can all be sliced from it.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each `*_np` attribute is a NumPy view sharing memory with the
        # corresponding CPU tensor, so NumPy writes land in the tensor.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        The sampling metadata is refreshed (via
        `input_batch.refresh_sampling_metadata`) only if requests were removed
        from or added to the persistent batch in this step.

        Steps:
            1. Drop finished requests from ``self.requests``,
               ``self.encoder_cache`` and the persistent batch.
            2. Free encoder outputs listed in ``free_encoder_input_ids``.
            3. Remove requests not scheduled this step from the persistent
               batch, keeping their cached state.
            4. Create cached states for newly scheduled requests.
            5. Update cached states of running/resumed requests and, for those
               still in the persistent batch, update their rows in place.
            6. Insert new/resumed requests into freed rows (smallest index
               first), condense remaining holes, and refresh sampling metadata
               if the batch composition changed.

        Args:
            scheduler_output: Scheduler decisions for the current step.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                # Drop the per-request dict once it becomes empty.
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                # Seeded requests get their own generator, seeded from
                # sampling_params.seed.
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                # Collect grid metadata from all multimodal inputs of the
                # request to compute its M-RoPE positions once, up front.
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in self.requests[req_id].mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])

                hf_config = self.model_config.hf_config

                self.requests[req_id].mrope_positions, \
                    self.requests[req_id].mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        self.requests[req_id].prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                    )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            num_computed_tokens = req_data.num_computed_tokens
            req_state.num_computed_tokens = num_computed_tokens
            # Add the sampled token(s) from the previous step (if any).
            # This doesn't include "unverified" tokens like spec decode tokens.
            num_new_tokens = (num_computed_tokens +
                              len(req_data.new_token_ids) -
                              req_state.num_tokens)
            if num_new_tokens == 1:
                # Avoid slicing list in most common case.
                req_state.output_token_ids.append(req_data.new_token_ids[-1])
            elif num_new_tokens > 0:
                req_state.output_token_ids.extend(
                    req_data.new_token_ids[-num_new_tokens:])
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                num_computed_tokens)
            # Only the newly allocated blocks need to be written to the row.
            start_index = (len(req_state.block_ids) -
                           len(req_data.new_block_ids))
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)
            # Add new_token_ids to token_ids_cpu.
            start_token_index = num_computed_tokens
            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
            self.input_batch.token_ids_cpu[
                req_index,
                start_token_index:end_token_index] = req_data.new_token_ids
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
            # Add spec_token_ids to token_ids_cpu.
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, ())
            if spec_token_ids:
                start_index = end_token_index
                end_token_index += len(spec_token_ids)
                self.input_batch.token_ids_cpu[
                    req_index, start_index:end_token_index] = spec_token_ids
            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
            self.input_batch.num_tokens[req_index] = end_token_index

        # Check if the batch has changed. If not, we can skip copying the
        # sampling metadata from CPU to GPU.
        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first: sorting descending makes
        # list.pop() return the smallest remaining index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        if batch_changed:
            self.input_batch.refresh_sampling_metadata()

    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
        """Build the model input tensors and attention metadata for one step.

        Fills the persistent CPU buffers (input ids, positions, slot mapping,
        query start locations, sequence lengths) for the requests in
        ``self.input_batch``, copies them to ``self.device`` with
        ``non_blocking=True``, and assembles the ``FlashAttentionMetadata``
        (including cascade-attention fields when a common prefix is used).

        Args:
            scheduler_output: Scheduler decisions for this step. At least one
                token and at least one request must be scheduled.

        Returns:
            ``(attn_metadata, logits_indices)``. Without spec decode,
            ``logits_indices`` holds the index of the last scheduled token of
            each request in the flattened token batch; with spec decode it
            comes from ``self._calc_spec_decode_metadata``.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request, in persistent
        # batch order, with a single comprehension instead of a per-index
        # Python loop.
        # NOTE(review): assumes input_batch.req_ids has exactly num_reqs
        # entries — confirm against InputBatch.
        tokens = [
            scheduler_output.num_scheduled_tokens[req_id]
            for req_id in self.input_batch.req_ids
        ]
        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
        max_num_scheduled_tokens = max(tokens)

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions: each token's position is its request's computed
        # token count plus its offset within the scheduled chunk.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is token_ids_cpu.shape[1], the per-request row width of the
        # token buffer (presumably max_model_len).
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # First compute each token's index into the flattened block table,
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # Gather each token's physical block number from the flattened CPU
        # block table (tensor advanced indexing with a NumPy index array).
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        # slot = block_number * block_size + offset within the block.
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        max_seq_len = self.seq_lens_np[:num_reqs].max()

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)
        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
                                                   non_blocking=True)
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
        common_prefix_len = self._compute_cascade_attn_prefix_len(
            num_scheduled_tokens,
            scheduler_output.num_common_prefix_blocks,
        )
        use_cascade = common_prefix_len > 0
        if use_cascade:
            # TODO: Optimize.
            cu_prefix_query_lens = torch.tensor(
                [0, total_num_scheduled_tokens],
                dtype=torch.int32,
                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
        )

        use_spec_decode = len(
            scheduler_output.scheduled_spec_decode_tokens) > 0
        if use_spec_decode:
            logits_indices = self._calc_spec_decode_metadata(
                scheduler_output, cu_num_tokens)
        else:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices

    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
        use_cascade = FlashAttentionBackend.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=False,  # FIXME
            use_sliding_window=self.sliding_window is not None,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

680
681
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
682
        for index, req_id in enumerate(self.input_batch.req_ids):
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

732
733
734
735
    def _calc_spec_decode_metadata(
        self,
        scheduler_output: "SchedulerOutput",
        cu_num_tokens: np.ndarray,
736
    ) -> torch.Tensor:
737
738
739
        # Get the number of spec decode tokens for each request.
        num_reqs = self.input_batch.num_reqs
        num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
740
        for i, req_id in enumerate(self.input_batch.req_ids):
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
            num_spec_decode_tokens[i] = len(
                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))

        # Get spec decode logits indices.
        # E.g.,   num_scheduled_tokens: [4, 100, 3,   100, 2]
        #         cu_num_tokens:        [4, 104, 107, 207, 209]
        #         num_spec_tokens_list: [3, 0,   2,   0,   1]
        #         num_sampled_tokens:   [4, 1,   3,   1,   2]
        #         spec_decode_logits_indices:
        #                 [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        num_sampled_tokens = num_spec_decode_tokens + 1
        # logits_start_loc: [0, 103, 104, 206, 207]
        logits_start_loc = cu_num_tokens - num_sampled_tokens
        # [0, 103, 104, 206, 207] ->
        #               [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
        logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
        # The following three lines:
        # [4, 1,   3,   1,   2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
        # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
        #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        cumsums_sampled_offsets = np.repeat(
            cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
        # Step 3.  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        #       -  [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
        #      -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
        total_num_sampled_tokens = num_sampled_tokens.sum()
        sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                          cumsums_sampled_offsets)

        # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
        # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
        spec_decode_logits_indices = logits_start_loc + sampled_arange
        return torch.from_numpy(spec_decode_logits_indices).to(
            self.device, non_blocking=True)

778
779
780
781
782
783
784
    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: List[MultiModalKwargs] = []
785
        req_input_ids: List[Tuple[str, int]] = []
786
787
788
789
790
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            for output in curr_group_outputs:
                encoder_outputs.append(output)
819
820
821
822
823
824
825
826
827
828
829
830

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> List[torch.Tensor]:
        encoder_outputs: List[torch.Tensor] = []
831
        for req_id in self.input_batch.req_ids:
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

864
865
866
    def get_model(self) -> nn.Module:
        """Return ``self.model`` (assigned by ``load_model``)."""
        model = self.model
        return model

867
868
869
870
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one model step for the scheduled batch.

        The step:
        1. applies ``scheduler_output`` via ``_update_states``;
        2. runs the multimodal encoder if the model is multimodal;
        3. builds the decoder inputs, padding the token count for CUDA graphs
           when it fits the largest captured size;
        4. runs the model;
        5. on the last pipeline rank, computes logits, samples, and packages
           the results.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Activations from the previous pipeline
                stage. Asserted to be present on every rank except the first.

        Returns:
            The model's hidden states on non-last pipeline ranks; otherwise a
            ``ModelRunnerOutput`` with sampled tokens, optional draft tokens,
            and logprobs.
        """
        self._update_states(scheduler_output)

        # Run the multimodal encoder (if any) and collect its outputs.
        encoder_outputs = []
        if self.is_multimodal_model:
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)

        # Prepare the decoder inputs.
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)

        num_scheduled = scheduler_output.total_num_scheduled_tokens
        fits_cuda_graph = (self.use_cuda_graph
                           and num_scheduled <= self.cudagraph_batch_sizes[-1])
        if fits_cuda_graph:
            # Piecewise CUDA graphs: pad up to a captured batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled)
        else:
            # Eager mode: no padding.
            num_input_tokens = num_scheduled
        attn_metadata.num_input_tokens = num_input_tokens

        if not self.is_multimodal_model:
            # Text-only models consume token ids. Feeding embeddings instead
            # would also work, but it would leave the embedding layer outside
            # the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        else:
            # NOTE(woosuk): Multimodal models always consume embeddings, even
            # for pure-text input, so text tokens and vision embeddings share
            # a single input path.
            token_ids = self.input_ids[:num_scheduled]
            if encoder_outputs:
                embeds = self.model.get_input_embeddings(
                    token_ids, encoder_outputs)
            else:
                embeds = self.model.get_input_embeddings(token_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled].copy_(embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None

        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            assert intermediate_tensors is not None
            assert self.intermediate_tensors is not None
            # Copy the received activations into the persistent buffers and
            # hand the model views of those buffers.
            for name, tensor in intermediate_tensors.items():
                self.intermediate_tensors[name][:num_input_tokens].copy_(
                    tensor[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                name: buf[:num_input_tokens]
                for name, buf in self.intermediate_tensors.items()
            })

        # Run the decoder on the persistent (CUDA-graph friendly) buffers.
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=self.kv_caches,
                attn_metadata=None,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # Mid-pipeline stages pass their hidden states onward.
            return hidden_states

        hidden_states = hidden_states[:num_scheduled]
        logits = self.model.compute_logits(
            hidden_states[logits_indices], None)

        # Sample the next token(s) and get logprobs if needed.
        sampling_metadata = self.input_batch.get_sampling_metadata(
            scheduler_output.scheduled_spec_decode_tokens)
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        for batch_idx, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                continue
            # The request is still partial, so its sampled token is ignored.
            # Rewind the generator as if that token had never been drawn.
            generator = self.input_batch.generators.get(batch_idx)
            if generator is not None:
                # This relies on cuda-specific torch-internal impl details
                generator.set_offset(generator.get_offset() - 4)

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = (None if logprobs_tensors is None else
                          logprobs_tensors.tolists())

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Extract the valid generated tokens for each request.
        sampled_token_ids = sampler_output.sampled_token_ids
        if sampled_token_ids.shape[-1] == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Spec decode: drop the INVALID_TOKEN_ID entries of each row.
            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            # TODO(woosuk): Optimize this.
            valid_sampled_token_ids = [
                row.tolist()
                for row in sampled_token_ids[valid_mask].split(gen_lens)
            ]

        spec_token_ids = None
        if self.use_spec_decode:
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids)

        return ModelRunnerOutput(
            req_ids=self.input_batch.req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            spec_token_ids=spec_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )

1020
1021
1022
1023
1024
1025
    def generate_draft_token_ids(
        self,
        sampled_token_ids: List[List[int]],
    ) -> List[List[int]]:
        """Propose n-gram draft tokens for each request in the batch.

        For request ``i``, the tokens sampled in this step are first written to
        ``self.input_batch.token_ids_cpu[i]`` starting at
        ``num_tokens_no_spec[i]``. The drafter then proposes draft tokens from
        that row's token history.

        A request gets an empty draft list when:
        * it sampled no tokens;
        * its history would reach the width of ``token_ids_cpu`` (no room is
          left for the sampled tokens and/or any draft token);
        * the drafter finds no proposal.

        Args:
            sampled_token_ids: Per-request lists of tokens produced in this
                step, in batch order.

        Returns:
            One list of draft token ids per request, in batch order.
        """
        # TODO(woosuk): Optimize.
        token_ids_cpu = self.input_batch.token_ids_cpu
        # Number of token slots available in each request's row.
        row_capacity = token_ids_cpu.shape[1]
        draft_token_ids: List[List[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            num_sampled_ids = len(sampled_ids)
            if not num_sampled_ids:
                # Skip speculative decoding.
                draft_token_ids.append([])
                continue

            start_idx = self.input_batch.num_tokens_no_spec[i]
            end_idx = start_idx + num_sampled_ids
            if end_idx >= row_capacity:
                # Past the row width, the slice assignment below would raise
                # (the clipped slice cannot hold all sampled ids). At exactly
                # the width, there is no slot left for any draft token.
                draft_token_ids.append([])
                continue

            # Add sampled_token_ids to token_ids_cpu.
            token_ids_cpu[i, start_idx:end_idx] = sampled_ids
            drafter_output = self.drafter.propose(
                token_ids_cpu[i, :end_idx],
                self.speculative_config.ngram_prompt_lookup_min,
                self.speculative_config.num_speculative_tokens,
            )
            if drafter_output is None or len(drafter_output) == 0:
                draft_token_ids.append([])
            else:
                draft_token_ids.append(drafter_output.tolist())
        return draft_token_ids

1048
1049
1050
    def load_model(self) -> None:
        """Load the model into ``self.model`` and record its memory use.

        When ``self.lora_config`` is set, the loaded model is wrapped through
        ``self.load_lora_model``. The device memory consumed during loading is
        stored in ``self.model_memory_usage`` and logged together with the
        elapsed load time.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as profiler:  # noqa: SIM117
            load_start = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model,
                    self.model_config,
                    self.scheduler_config,
                    self.lora_config,
                    self.device,
                )
            load_end = time.perf_counter()
        self.model_memory_usage = profiler.consumed_memory
        bytes_per_gib = float(2**30)
        logger.info("Loading model weights took %.4f GB and %.6f seconds",
                    self.model_memory_usage / bytes_per_gib,
                    load_end - load_start)
1064

1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> Dict[str, LogprobsTensors]:
        """Compute prompt logprobs for the requests that requested them.

        This covers each request in ``self.input_batch.num_prompt_logprobs``
        (request id -> number of prompt logprobs). For each one:
        * the hidden states of its tokens scheduled in this step are turned
          into logits;
        * for each position, the logprob of the following prompt token is
          gathered via ``self.model.sampler.gather_logprobs``.

        Requests whose last prompt chunk is handled in this step are removed
        from ``self.input_batch.num_prompt_logprobs``.

        Args:
            hidden_states: Hidden states of this step's scheduled tokens. Each
                request's rows start at ``self.query_start_loc_np[req_idx]``.
            scheduler_output: Provides per-request scheduled token counts.

        Returns:
            Mapping from request id to ``LogprobsTensors`` on the CPU; empty
            when no request asked for prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Determine number of logits to retrieve. Position p predicts
            # prompt token p + 1, so targets start one past the computed
            # tokens.
            start_tok = request.num_computed_tokens + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens < num_remaining_tokens:
                # This is a chunk, more tokens remain.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill), only the
            # tokens scheduled in this step produce prompt logprobs.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Slice on the host first, so only this
            # chunk's targets are materialized and copied to the device
            # instead of the whole prompt at every chunk. The explicit dtype
            # keeps int64 even for an empty slice.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok + num_logits],
                dtype=torch.long,
            ).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            prompt_logprobs_dict[req_id] = LogprobsTensors(
                token_ids.to("cpu", non_blocking=True),
                logprobs.to("cpu", non_blocking=True),
                ranks.to("cpu", non_blocking=True),
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict (after iteration, so the dict is not
        # mutated while it is being iterated).
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        torch.cuda.synchronize()

        return prompt_logprobs_dict

1135
1136
1137
1138
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        kv_caches: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run the model once on ``num_tokens`` placeholder tokens.

        The inputs are slices of the runner's persistent buffers. The forward
        context carries no attention metadata.

        Args:
            num_tokens: Number of tokens in the dummy batch.
            kv_caches: KV cache tensors handed to the model; defaults to
                ``self.kv_caches``.

        Returns:
            Whatever the model's forward call returns.
        """
        model = self.model
        caches = self.kv_caches if kv_caches is None else kv_caches

        # Multimodal models take embeddings; text-only models take token ids.
        if self.is_multimodal_model:
            input_ids, inputs_embeds = None, self.inputs_embeds[:num_tokens]
        else:
            input_ids, inputs_embeds = self.input_ids[:num_tokens], None

        positions = (self.mrope_positions[:, :num_tokens]
                     if self.uses_mrope else self.positions[:num_tokens])

        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            # Lazily allocate the persistent intermediate-tensor buffers,
            # sized for the maximum number of tokens.
            if self.intermediate_tensors is None:
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                key: buf[:num_tokens]
                for key, buf in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=caches,
                attn_metadata=None,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        return hidden_states

    def profile_run(self) -> None:
1181
1182
1183
1184
1185
1186
1187
1188
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        dummy_kv_caches = [
1189
            torch.tensor((), dtype=torch.float32, device=self.device)
1190
1191
            for _ in range(self.num_attn_layers)
        ]
1192
1193

        # Profile with multimodal encoder & encoder cache.
1194
1195
1196
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):
1197

1198
            # NOTE: Currently model is profiled with a single non-text
1199
1200
            # modality with the max possible input tokens even when
            # it supports multiple.
1201
            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
1202
1203
1204
1205
1206
                self.model_config)
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
1207
1208
1209
1210
1211
1212
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)
1213
1214
1215

            # Check how many items of this modality can be supported by
            # the decoder budget.
1216
1217
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data

1241
1242
1243
1244
            # Dummy data definition in V0 may contain multiple multimodal items
            # (e.g, multiple images) for a single request, therefore here we
            # always replicate first item by max_num_mm_items times since in V1
            # they are scheduled to be processed separately.
1245
1246

            # Case when models have a merged processor, their dummy data is
1247
1248
            # already batched `MultiModalKwargs`, therefore we take the first
            # `MultiModalKwargsItem` from the desired modality to profile on.
1249
            if isinstance(dummy_mm_data, MultiModalKwargs):
1250
1251
1252
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
1253
1254
1255
1256

            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through input
            # mapper.
1257
1258
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
1259
            else:
1260
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
1261
                    mm_data=dummy_mm_data,
1262
                    mm_hashes=None,
1263
1264
1265
1266
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)
                dummy_mm_kwargs = mm_kwargs_list[0]

1267
            batched_dummy_mm_inputs = MultiModalKwargs.batch(
1268
                [dummy_mm_kwargs] * max_num_mm_items)
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
        min_tokens_per_req: int = num_tokens // num_reqs

        num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
                                                    dtype=np.int32)
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        with self.maybe_profile_with_lora(self.lora_config,
                                          num_scheduled_tokens):
            # Trigger compilation for general shape.
            hidden_states = self._dummy_run(self.max_num_tokens,
                                            dummy_kv_caches)
1305
1306
1307
1308
1309
1310
            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
                logits = self.model.compute_logits(hidden_states, None)
                # TODO(woosuk): Consider the memory usage of the sampler.
            else:
                logits = None
1311
1312
1313
            torch.cuda.synchronize()
            del hidden_states, logits
            self.encoder_cache.clear()
1314
        gc.collect()
1315
1316

    def capture_model(self) -> None:
        """
        Run the model once per configured CUDA graph batch size so that a
        graph is captured for each of those shapes, then log how long this
        took and how much device memory it consumed.

        Does nothing except log a warning when `self.use_cuda_graph` is false.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        # Free device memory before capture; the difference after capture is
        # reported as the graph size. NOTE(review): this is a device-wide
        # free-memory delta, so any other allocation made meanwhile is
        # counted too.
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        num_warmups = (
            self.vllm_config.compilation_config.cudagraph_num_of_warmups)

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                # Warm-up runs at this size, followed by one more run.
                # Presumably that final run is the one that gets captured.
                for _ in range(num_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        Allocates one zero-filled tensor per layer listed in
        `kv_cache_config.kv_cache_spec`. It then passes the resulting
        name->tensor mapping to `bind_kv_cache`, together with the static
        forward context and `self.kv_caches`. Presumably that call attaches
        each tensor to its attention layer; see `bind_kv_cache`.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: if the config has more than one KV cache
                group (hybrid models), or if a layer's KV cache spec is not
                a `FullAttentionSpec`.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            # The byte budget of each layer must be a whole number of pages;
            # each page holds one block.
            assert tensor_config.size % layer_spec.page_size_bytes == 0, (
                f"KV cache size {tensor_config.size} of layer {layer_name} "
                f"is not a multiple of its page size "
                f"{layer_spec.page_size_bytes}.")
            num_blocks = tensor_config.size // layer_spec.page_size_bytes
            if not isinstance(layer_spec, FullAttentionSpec):
                raise NotImplementedError(
                    f"Unsupported KV cache spec "
                    f"{type(layer_spec).__name__} for layer {layer_name}.")
            kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
                num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                layer_spec.head_size)
            kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                dtype=layer_spec.dtype,
                                                device=self.device)

        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Build the KV cache specification from the attention modules found in
        the static forward context.

        Returns:
            KVCacheSpec: A dictionary mapping each layer name to a
            `FullAttentionSpec` describing its KV cache. Decoder layers are
            included; encoder and encoder-only attention layers need no KV
            cache and are left out.

        Raises:
            NotImplementedError: for encoder-decoder attention layers.
            ValueError: for an unrecognized attention type.
        """
        context_modules = (
            self.vllm_config.compilation_config.static_forward_context)
        cache_block_size = self.vllm_config.cache_config.block_size
        cacheless_types = (AttentionType.ENCODER, AttentionType.ENCODER_ONLY)

        spec: KVCacheSpec = {}
        for name, module in context_modules.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(module, Attention)
            kind = module.attn_type

            if kind == AttentionType.DECODER:
                spec[name] = FullAttentionSpec(
                    block_size=cache_block_size,
                    num_kv_heads=module.num_kv_heads,
                    head_size=module.head_size,
                    dtype=module.dtype,
                )
                continue

            if kind in cacheless_types:
                # Encoder-side attention keeps no KV cache.
                continue

            if kind == AttentionType.ENCODER_DECODER:
                raise NotImplementedError

            raise ValueError(f"Unknown attention type: {kind}")

        return spec