# SPDX-License-Identifier: Apache-2.0

3
import gc
4
import time
5
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
6
7
8
9
10
11

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

12
13
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
14
from vllm.config import CompilationLevel, VllmConfig
15
from vllm.distributed.parallel_state import get_pp_group, graph_capture
16
from vllm.forward_context import set_forward_context
17
from vllm.inputs import INPUT_REGISTRY
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
20
from vllm.model_executor.model_loader import get_model
21
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
22
from vllm.multimodal.utils import group_mm_inputs_by_modality
23
from vllm.sampling_params import SamplingType
24
from vllm.sequence import IntermediateTensors
25
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
26
                        LayerBlockType, cdiv, is_pin_memory_available)
27
28
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
29
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
30
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
31
32
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
33
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
34
from vllm.v1.sample.metadata import SamplingMetadata
35
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
36
from vllm.v1.utils import bind_kv_cache
37
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
38
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
39
40

if TYPE_CHECKING:
    # Imported for type annotations only (e.g. the "SchedulerOutput" string
    # annotations on the runner's methods).
    from vllm.v1.core.scheduler_output import SchedulerOutput

logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, persistent batch state and reusable buffers.

        Args:
            vllm_config: Engine configuration. Its sub-configs are cached as
                attributes on the runner.
            device: Device on which the persistent GPU buffers (input ids,
                positions, input embeddings, M-RoPE positions) are allocated.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model's dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY
        self.uses_mrope = model_config.uses_mrope

        # NOTE: Initialized client is only used for processing dummy
        # multimodal data into multimodal kwargs for GPU memory profiling.
        # Only applicable to multimodal models with legacy input mapper.
        self.mm_input_mapper_profiling = MMInputCacheClient(self.model_config)
        self.mm_input_mapper_profiling.use_cache = False

        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
            model_config=model_config,
            scheduler_config=scheduler_config,
        )
        self.max_num_encoder_input_tokens = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # CUDA graphs are only used with piecewise compilation and when eager
        # mode is not enforced.
        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        # self.intermediate_tensors  # Set after load_model

        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            # NOTE: `mrope_positions` is implemented with one additional dummy
            # position on purpose to make it non-contiguous so that it can work
            # with torch compile.
            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
            # the modality of inputs. For text-only inputs, each dimension has
            # identical position IDs, making M-RoPE functionally equivalent to
            # 1D-RoPE.
            # See page 5 of https://arxiv.org/abs/2409.12191
            self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
                                               dtype=torch.int64,
                                               device=self.device)
            self.mrope_positions_cpu = torch.zeros(
                (3, self.max_num_tokens + 1),
                dtype=torch.int64,
                device="cpu",
                pin_memory=self.pin_memory)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # The arange must be long enough to index any per-request, per-token
        # or per-position range used when preparing inputs.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len,
                                       self.max_num_tokens),
                                   dtype=np.int32)
        self.arange_cpu = torch.from_numpy(self.arange_np)

        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each CPU staging tensor is paired with a numpy view of the same
        # storage (`.numpy()` shares memory) for fast CPU-side writes.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
        """Update the cached states and the persistent batch with the scheduler
        output.

        The updated states are used by the `_prepare_inputs` function to create
        the input GPU tensors for the model.

        Args:
            scheduler_output: The scheduling decision for this step (finished,
                new, and cached requests plus freed encoder inputs).

        Returns:
            True if there is a new/resumed/paused/finished request in the batch.
            If False, we can skip copying SamplingMetadata to the GPU.
        """
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
        # then resubmitted with the same ID. In this case, we treat them as two
        # distinct requests - clearing the cached states for the first request
        # and handling the second as a new request.
        removed_req_indices: List[int] = []
        for req_id in scheduler_output.finished_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                # Drop the per-request dict once it holds no more outputs.
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
        # them from the persistent batch but keep their cached states since
        # they will be scheduled again sometime in the future.
        scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
        cached_req_ids = self.input_batch.req_id_to_index.keys()
        # The set difference is materialized here, so removing entries from
        # `req_id_to_index` inside the loop below is safe.
        unscheduled_req_ids = cached_req_ids - scheduled_req_ids
        # NOTE(woosuk): The persistent batch optimization assumes that
        # consecutive batches contain mostly the same requests. If batches
        # have low request overlap (e.g., alternating between two distinct
        # sets of requests), this optimization becomes very inefficient.
        for req_id in unscheduled_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            assert req_index is not None
            removed_req_indices.append(req_index)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            # Seeded sampling needs a dedicated, per-request generator.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            req_state = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
                lora_request=new_req_data.lora_request,
            )
            self.requests[req_id] = req_state

            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            if self.uses_mrope:
                # Collect the grid metadata of every multimodal input of this
                # request to derive its 3D position ids.
                image_grid_thw = []
                video_grid_thw = []
                second_per_grid_ts = []
                for mm_input in req_state.mm_inputs:
                    if mm_input.get("image_grid_thw") is not None:
                        image_grid_thw.extend(
                            mm_input["image_grid_thw"].tolist())
                    if mm_input.get("video_grid_thw") is not None:
                        video_grid_thw.extend(
                            mm_input["video_grid_thw"].tolist())
                    if mm_input.get("second_per_grid_ts") is not None:
                        second_per_grid_ts.extend(
                            mm_input["second_per_grid_ts"])

                hf_config = self.model_config.hf_config

                req_state.mrope_positions, req_state.mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions_tensor(
                        req_state.prompt_token_ids,
                        hf_config=hf_config,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        second_per_grid_ts=second_per_grid_ts,
                    )

            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
        for req_data in scheduler_output.scheduled_cached_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            # Update the cached states.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
                req_state.block_ids = req_data.new_block_ids

            req_index = self.input_batch.req_id_to_index.get(req_id)
            if req_index is None:
                # The request is not in the persistent batch.
                # The request was either preempted and resumed later, or was not
                # scheduled in the previous step and needs to be added again.
                req_ids_to_add.append(req_id)
                continue

            # Update the persistent batch.
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)
            start_index = len(req_state.block_ids) - len(
                req_data.new_block_ids)
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)

        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first.
        # Sorting descending lets `pop()` hand out the smallest index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        # The remaining indices are still in descending order.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        return batch_changed

    def _prepare_inputs(
        self, scheduler_output: "SchedulerOutput"
    ) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
        """Build the model inputs and attention metadata for one step.

        The CPU staging buffers (input ids, positions, slot mapping, query
        start locations, sequence lengths) are filled from the persistent
        batch. Their copies to ``self.device`` are launched, and a
        ``FlashAttentionMetadata`` is assembled. If the scheduler provided
        speculative tokens, they are first written into the persistent token
        buffer so they are gathered as regular input ids.

        Args:
            scheduler_output: The scheduling decision for this step.

        Returns:
            ``(attn_metadata, logits_indices)``, where ``logits_indices``
            selects the flattened token positions whose logits are computed.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens (and speculative tokens) for each
        # request, in persistent-batch order.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens_list: List[int] = []
        all_spec_token_ids: List[int] = []
        num_spec_tokens_list: List[int] = []
        for _, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            num_scheduled_tokens_list.append(
                scheduler_output.num_scheduled_tokens[req_id])
            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
                req_id, [])
            all_spec_token_ids.extend(spec_token_ids)
            num_spec_tokens_list.append(len(spec_token_ids))

        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
                                                    dtype=np.int32)
        max_num_scheduled_tokens = int(num_scheduled_tokens.max())
        assert max_num_scheduled_tokens > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # Equivalent to but faster than:
        # np.concatenate([np.arange(n) for n in num_scheduled_tokens])
        # Step 1. [2, 5, 3] -> [2, 7, 10]
        cu_num_tokens = np.cumsum(num_scheduled_tokens)
        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
                                    num_scheduled_tokens)
        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets

        # Get positions.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Calculate M-RoPE positions.
        # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
        if self.uses_mrope:
            self._calc_mrope_positions(scheduler_output)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])

        use_spec_decode = len(all_spec_token_ids) > 0
        if use_spec_decode:
            num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)

            # 1. Write spec_token_ids to input batch.
            # Step 1. Get req indices that perform spec decode and repeat
            #         the req indices by the number of spec tokens. Note
            #         for requests that don't perform spec decode, the
            #         number of spec tokens is 0 and the req index is
            #         repeated 0 times.
            # E.g., num_spec_tokens_list:            [3, 0, 2, 0, 1]
            #       spec_req_indices:                [0, 0, 0, 2, 2, 4]
            spec_req_indices = np.repeat(self.arange_np[:num_reqs],
                                         num_spec_tokens_np)
            # spec_offsets: 1-based offsets within each spec token list,
            # computed with the same cumsum trick used for `arange` above.
            # E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
            #       cu_num_spec_tokens:   [3, 3, 5, 5, 6]
            #       spec_start_offsets:   [0, 0, 0, 3, 3, 5]
            #       spec_offsets:         [1, 2, 3, 1, 2, 1]
            total_num_spec_tokens = len(all_spec_token_ids)
            cu_num_spec_tokens = np.cumsum(num_spec_tokens_np)
            spec_start_offsets = np.repeat(
                cu_num_spec_tokens - num_spec_tokens_np, num_spec_tokens_np)
            spec_offsets = (self.arange_np[:total_num_spec_tokens] -
                            spec_start_offsets + 1)
            # spec_seq_offsets: offsets within each sequence.
            # E.g., num_computed_tokens_cpu:   [1, 4, 3, 6, 2]
            #       after repeating:           [1, 1, 1, 3, 3, 2]
            #       spec_seq_offsets:  [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
            #                                = [2, 3, 4, 4, 5, 3]
            spec_seq_offsets = np.repeat(
                self.input_batch.num_computed_tokens_cpu[:num_reqs],
                num_spec_tokens_np) + spec_offsets
            # cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
            cumsums_spec_offsets = (
                spec_seq_offsets +
                spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
            cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
                torch.int64)
            spec_token_ids_tensor = torch.tensor(
                all_spec_token_ids,
                device="cpu",
                dtype=self.input_ids_cpu.dtype)

            # Step 2. Write spec token ids to input_ids_cpu.
            # NOTE(review): this write only persists if `flatten()` returns a
            # view, i.e. if token_ids_cpu_tensor is contiguous — presumably
            # true for the persistent buffer; verify in InputBatch.
            self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
                0, cumsums_spec_offsets, spec_token_ids_tensor)

            # 2. Get spec decode logits indices.
            # E.g.,   num_scheduled_tokens: [4, 100, 3,   100, 2]
            #         cu_num_tokens:        [4, 104, 107, 207, 209]
            #         num_spec_tokens_list: [3, 0,   2,   0,   1]
            #         num_sampled_tokens:   [4, 1,   3,   1,   2]
            #         spec_decode_logits_indices:
            #                 [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
            num_sampled_tokens = num_spec_tokens_np + 1
            # logits_start_loc: [0, 103, 104, 206, 207]
            logits_start_loc = cu_num_tokens - num_sampled_tokens
            # [0, 103, 104, 206, 207] ->
            #               [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
            logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
            # The following three lines:
            # [4, 1,   3,   1,   2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
            # Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
            cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
            # Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
            #         -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
            cumsums_sampled_offsets = np.repeat(
                cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
            # Step 3.  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            #       -  [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
            #      -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
            total_num_sampled_tokens = num_sampled_tokens.sum()
            sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
                              cumsums_sampled_offsets)

            # [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
            # [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
            spec_decode_logits_indices = logits_start_loc + sampled_arange

        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
        max_seq_len = int(self.seq_lens_np[:num_reqs].max())

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        if self.uses_mrope:
            # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
            self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)
        else:
            # Common case (1D positions)
            self.positions[:total_num_scheduled_tokens].copy_(
                self.positions_cpu[:total_num_scheduled_tokens],
                non_blocking=True)
        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
                                                   non_blocking=True)
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
        common_prefix_len = self._compute_cascade_attn_prefix_len(
            num_scheduled_tokens,
            scheduler_output.num_common_prefix_blocks,
        )
        use_cascade = common_prefix_len > 0
        if use_cascade:
            # TODO: Optimize.
            cu_prefix_query_lens = torch.tensor(
                [0, total_num_scheduled_tokens],
                dtype=torch.int32,
                device=self.device)
            prefix_kv_lens = torch.tensor([common_prefix_len],
                                          dtype=torch.int32,
                                          device=self.device)
            suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
            suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
        )

        if use_spec_decode:
            logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
                self.device, non_blocking=True)
        else:
            # NOTE(woosuk): Due to chunked prefills, the batch may contain
            # partial requests. While we should not sample any token
            # from these partial requests, we do so for simplicity.
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1

        # Hot-Swap lora model
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

        return attn_metadata, logits_indices

    def _compute_cascade_attn_prefix_len(
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

        NOTE(woosuk): The common prefix length returned by this function
        represents the length used specifically for cascade attention, not the
        actual number of tokens shared between requests. When cascade attention
        is disabled (use_cascade=False), this function returns 0 even if
        requests share common tokens. Additionally, the common prefix length is
        truncated to a multiple of the block size and may be further truncated
        due to implementation details explained below.

        Args:
            num_scheduled_tokens: Number of tokens scheduled per request.
            num_common_prefix_blocks: Number of shared KV cache blocks.

        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0

        # NOTE(woosuk): Cascade attention uses two attention kernels: one
        # for the common prefix and the other for the rest. For the first
        # kernel, we concatenate all the query tokens (possibly from
        # different requests) and treat them as if they are from the same
        # request. Then, we use bi-directional attention to process the
        # common prefix in the KV cache. Importantly, this means that the
        # first kernel does not do any masking.

        # Consider the following example:
        # Request 1's input query: [D, E, X]
        # Request 1's kv cache: [A, B, C, D, E, X]
        # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
        # Request 2's input query: [E, Y]
        # Request 2's kv cache: [A, B, C, D, E, Y]
        # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

        # If we use [A, B, C, D, E] as the common prefix, then the
        # first kernel will compute the bi-directional attention between
        # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
        # However, this is wrong because D in Request 1 should not attend to
        # E in the common prefix (i.e., we need masking).
        # To avoid this, [A, B, C, D] should be the common prefix.
        # That is, the common prefix should be capped by the minimum
        # num_computed_tokens among the requests, and plus one to include
        # the first token of the query.

        # In practice, we use [A, B, C] as the common prefix, instead of
        # [A, B, C, D] (i.e., the common prefix is capped by the minimum
        # num_computed_tokens, without plus one).
        # This is because of an implementation detail: We want to always
        # use two kernels for cascade attention. Let's imagine:
        # Request 3's input query: [D]
        # Request 3's kv cache: [A, B, C, D]
        # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
        # If we use [A, B, C, D] as the common prefix for Request 1-3,
        # then Request 3 will be processed only by the first kernel,
        # and the second kernel will get an empty input. While this is not
        # a fundamental problem, our current implementation does not support
        # this case.
        num_reqs = len(num_scheduled_tokens)
        common_prefix_len = min(
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
        use_cascade = FlashAttentionBackend.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            use_alibi=False,  # FIXME
            use_sliding_window=self.sliding_window is not None,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0

708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
    def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
        mrope_pos_ptr = 0
        num_reqs = self.input_batch.num_reqs
        for index, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None

            req = self.requests[req_id]
            assert req.mrope_positions is not None

            num_computed_tokens = \
                self.input_batch.num_computed_tokens_cpu[index]
            num_scheduled_tokens = \
                scheduler_output.num_scheduled_tokens[req_id]
            num_prompt_tokens = len(req.prompt_token_ids)

            if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                prompt_part_len = max(0,
                                      num_prompt_tokens - num_computed_tokens)
                completion_part_len = max(
                    0, num_scheduled_tokens - prompt_part_len)
            else:
                prompt_part_len = num_scheduled_tokens
                completion_part_len = 0

            assert num_scheduled_tokens == prompt_part_len + completion_part_len

            if prompt_part_len > 0:
                # prompt's mrope_positions are pre-computed
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + prompt_part_len
                src_start = num_computed_tokens
                src_end = num_computed_tokens + prompt_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    req.mrope_positions[:,src_start:src_end]

                mrope_pos_ptr += prompt_part_len

            if completion_part_len > 0:
                # compute completion's mrope_positions on-the-fly
                dst_start = mrope_pos_ptr
                dst_end = mrope_pos_ptr + completion_part_len

                self.mrope_positions_cpu[:, dst_start:dst_end] = \
                    MRotaryEmbedding.get_next_input_positions_tensor(
                        req.mrope_position_delta,
                        context_len=num_computed_tokens +
                        prompt_part_len,
                        seq_len=num_computed_tokens +
                        prompt_part_len +
                        completion_part_len,
                    )

                mrope_pos_ptr += completion_part_len

763
764
    def _prepare_sampling(
        self,
        batch_changed: bool,
        req_to_spec_token_ids: Dict[str, List[int]],
    ) -> SamplingMetadata:
        """Build the sampling metadata for the current step.

        Args:
            batch_changed: Whether the persistent batch changed in this step.
                Its negation is passed as the third positional argument of
                ``self.input_batch.make_sampling_metadata``.
            req_to_spec_token_ids: Speculative token ids per request id,
                forwarded unchanged.

        Returns:
            The ``SamplingMetadata`` produced by the input batch.
        """
        # Map every tracked request (not only those in the current batch) to
        # its list of output token ids.
        req_id_output_token_ids: Dict[str, List[int]] = {
            req_id: req.output_token_ids
            for req_id, req in self.requests.items()
        }
        return self.input_batch.make_sampling_metadata(
            req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)

777
778
779
780
781
782
783
    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: List[MultiModalKwargs] = []
784
        req_input_ids: List[Tuple[str, int]] = []
785
786
787
788
789
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817

        # Batch mm inputs as much as we can: if a request in the batch has
        # multiple modalities or a different modality than the previous one,
        # we process it separately to preserve item order.
        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
        # in the same batch while still being able to benefit from batching
        # multimodal inputs. The proper solution should be reordering the
        # encoder outputs.
        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

        encoder_outputs = []
        for grouped_mm_inputs in grouped_mm_inputs_list:
            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                           device=self.device)

            # Run the encoder.
            # `curr_group_outputs` is either of the following:
            # 1. A tensor of shape (num_items, feature_size, hidden_size)
            # in case feature_size is fixed across all multimodal items.
            # 2. A list or tuple (length: num_items) of tensors, each of shape
            # (feature_size, hidden_size) in case the feature size is dynamic
            # depending on the input multimodal items.
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)

            for output in curr_group_outputs:
                encoder_outputs.append(output)
818
819
820
821
822
823
824
825
826
827
828
829
830
831

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            if req_id not in self.encoder_cache:
                self.encoder_cache[req_id] = {}
            self.encoder_cache[req_id][input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> List[torch.Tensor]:
        encoder_outputs: List[torch.Tensor] = []
        num_reqs = self.input_batch.num_reqs
        for req_id in self.input_batch.req_ids[:num_reqs]:
832
            assert req_id is not None
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

865
866
867
    def get_model(self) -> nn.Module:
        """Return the model module held by this runner (set by
        ``load_model``)."""
        model = self.model
        return model

868
869
870
871
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[ModelRunnerOutput, torch.Tensor]:
        """Run one model step for the batch described by ``scheduler_output``.

        Updates the persistent batch, runs the multimodal encoder when
        needed, and runs the decoder. On the last pipeline-parallel (PP)
        rank, it also samples the next token(s) and records them in the
        input batch and the request states.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Activations received from the previous
                PP stage. Ignored on the first PP rank. On other ranks, when
                given, they are copied into the persistent
                ``self.intermediate_tensors`` buffers that the forward pass
                reads from.

        Returns:
            On a non-last PP rank, the model's forward output. On the last
            PP rank, a ``ModelRunnerOutput`` with the sampled token ids and
            any requested logprobs / prompt logprobs.
        """
        batch_changed = self._update_states(scheduler_output)

        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []

        # Prepare the decoder inputs.
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens
        attn_metadata.num_input_tokens = num_input_tokens

        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_input_tokens]
        else:
            positions = self.positions[:num_input_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            if intermediate_tensors is not None:
                # The forward pass below reads from the persistent buffers
                # (required for CUDA graphs). Stage the activations received
                # from the previous PP rank there first; otherwise they
                # would be silently discarded.
                for k, v in intermediate_tensors.items():
                    self.intermediate_tensors[k][:num_input_tokens].copy_(
                        v[:num_input_tokens], non_blocking=True)
            intermediate_tensors = IntermediateTensors({
                k: v[:num_input_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=self.kv_caches,
                attn_metadata=None,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        if not get_pp_group().is_last_rank:
            # For mid-pipeline stages, return the hidden states.
            return hidden_states

        # Drop the CUDA-graph padding rows before computing logits.
        hidden_states = hidden_states[:num_scheduled_tokens]
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self._prepare_sampling(
            batch_changed, scheduler_output.scheduled_spec_decode_tokens)
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        num_reqs = self.input_batch.num_reqs
        # (batch index, request state, seq_len) for requests whose scheduled
        # tokens reach their current num_tokens, i.e. whose sampled token is
        # kept.
        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
        for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len >= req_state.num_tokens:
                request_seq_lens.append((i, req_state, seq_len))
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

        # num_reqs entries should be non-None
        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Update batch with the valid generated tokens.
        sampled_token_ids = sampler_output.sampled_token_ids
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            # Exactly one sampled token per request.
            valid_sampled_token_ids = sampled_token_ids.tolist()
            for i, req_state, seq_len in request_seq_lens:
                token_id = valid_sampled_token_ids[i][0]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
                self.input_batch.num_tokens[i] += 1
        else:
            # Several tokens per request: drop entries equal to
            # INVALID_TOKEN_ID and keep the remaining ones of each row.
            valid_mask = sampled_token_ids != INVALID_TOKEN_ID
            gen_lens = valid_mask.sum(dim=1).tolist()
            valid_sampled_token_ids = [
                seq.tolist()
                for seq in sampled_token_ids[valid_mask].split(gen_lens)
            ]
            self.input_batch.num_tokens[:num_reqs] += gen_lens
            for i, req_state, seq_len in request_seq_lens:
                target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
                self.input_batch.token_ids_cpu[
                    i, target_slice] = valid_sampled_token_ids[i]
                req_state.output_token_ids.extend(valid_sampled_token_ids[i])

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=valid_sampled_token_ids,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )
        return model_runner_output

    def load_model(self) -> None:
        """Load the model and record the device memory it consumed.

        When ``self.lora_config`` is set, the loaded model is passed through
        ``self.load_lora_model`` and the result replaces ``self.model``. The
        memory measured by ``DeviceMemoryProfiler`` around loading is stored
        in ``self.model_memory_usage`` (in bytes) and logged in GiB.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> Dict[str, LogprobsTensors]:
        """Compute prompt logprobs for the batch requests that asked for them.

        For each request in ``self.input_batch.num_prompt_logprobs``, this
        computes the logprobs of the prompt tokens covered by this step's
        chunk. A request is removed from that dict once its final prompt
        chunk has been processed.

        Args:
            hidden_states: Hidden states for this step's scheduled tokens,
                sliced per request using ``self.query_start_loc_np`` offsets.
            scheduler_output: Provides the number of tokens scheduled per
                request.

        Returns:
            Mapping from request id to ``LogprobsTensors`` copied to CPU.
            Empty when no request in the batch asked for prompt logprobs.
        """
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)

            # Determine number of logits to retrieve.
            start_tok = request.num_computed_tokens + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens < num_remaining_tokens:
                # This is a chunk, more tokens remain.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for. Only this chunk's slice is converted
            # and moved to the device, instead of the whole prompt on every
            # chunk. The explicit dtype keeps an empty slice integer-typed.
            tgt_token_ids = torch.tensor(
                request.prompt_token_ids[start_tok:start_tok + num_logits],
                dtype=torch.int64).to(self.device, non_blocking=True)

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            prompt_logprobs_dict[req_id] = LogprobsTensors(
                token_ids.to("cpu", non_blocking=True),
                logprobs.to("cpu", non_blocking=True),
                ranks.to("cpu", non_blocking=True),
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        torch.cuda.synchronize()

        return prompt_logprobs_dict

1113
1114
1115
1116
    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        kv_caches: Optional[List[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run a forward pass over the first ``num_tokens`` entries of the
        persistent input buffers, with no attention metadata.

        On a non-first pipeline-parallel rank, this lazily allocates
        ``self.intermediate_tensors`` (sized for ``self.max_num_tokens``) if
        it does not exist yet, and feeds slices of it to the model.

        Args:
            num_tokens: Number of tokens to feed through the model.
            kv_caches: KV caches passed to the model; defaults to
                ``self.kv_caches``.

        Returns:
            The model's forward output.
        """
        model = self.model
        if kv_caches is None:
            kv_caches = self.kv_caches

        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None
        if self.uses_mrope:
            positions = self.mrope_positions[:, :num_tokens]
        else:
            positions = self.positions[:num_tokens]

        if get_pp_group().is_first_rank:
            intermediate_tensors = None
        else:
            if not hasattr(self, "intermediate_tensors"):
                self.intermediate_tensors = (
                    self.model.make_empty_intermediate_tensors(
                        batch_size=self.max_num_tokens,
                        dtype=self.model_config.dtype,
                        device=self.device))
            intermediate_tensors = IntermediateTensors({
                k: v[:num_tokens]
                for k, v in self.intermediate_tensors.items()
            })

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=None,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
            )
        return hidden_states

    def profile_run(self) -> None:
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        dummy_kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]
1170
1171

        # Profile with multimodal encoder & encoder cache.
1172
1173
1174
        # TODO: handle encoder-decoder models once we support them.
        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
                and self.encoder_cache_size > 0):
1175

1176
            # NOTE: Currently model is profiled with a single non-text
1177
1178
            # modality with the max possible input tokens even when
            # it supports multiple.
1179
            max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(  # noqa: E501
1180
1181
1182
1183
1184
                self.model_config)
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
1185
1186
1187
1188
1189
1190
            # the encoder budget.
            encoder_budget = min(self.max_num_encoder_input_tokens,
                                 self.encoder_cache_size)

            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
                                                   max_tokens_per_mm_item)
1191
1192
1193

            # Check how many items of this modality can be supported by
            # the decoder budget.
1194
1195
            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
                self.model_config)[dummy_data_modality]
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
            logger.info(
                "Encoder cache will be initialized with a budget of %s tokens,"
                " and profiled with %s %s items of the maximum feature size.",
                encoder_budget, max_num_mm_items, dummy_data_modality)

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data

1219
1220
1221
1222
            # Dummy data definition in V0 may contain multiple multimodal items
            # (e.g, multiple images) for a single request, therefore here we
            # always replicate first item by max_num_mm_items times since in V1
            # they are scheduled to be processed separately.
1223
1224

            # Case when models have a merged processor, their dummy data is
1225
1226
            # already batched `MultiModalKwargs`, therefore we take the first
            # `MultiModalKwargsItem` from the desired modality to profile on.
1227
            if isinstance(dummy_mm_data, MultiModalKwargs):
1228
1229
1230
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
1231
1232
1233
1234

            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through input
            # mapper.
1235
1236
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
1237
            else:
1238
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
1239
                    mm_data=dummy_mm_data,
1240
                    mm_hashes=None,
1241
1242
1243
1244
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)
                dummy_mm_kwargs = mm_kwargs_list[0]

1245
            batched_dummy_mm_inputs = MultiModalKwargs.batch(
1246
                [dummy_mm_kwargs] * max_num_mm_items)
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.
        num_reqs = self.scheduler_config.max_num_seqs
        num_tokens = self.max_num_tokens
        min_tokens_per_req: int = num_tokens // num_reqs

        num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs
        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
        assert sum(num_scheduled_tokens_list) == num_tokens
        assert len(num_scheduled_tokens_list) == num_reqs

        num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
                                                    dtype=np.int32)
        logit_indices = np.cumsum(num_scheduled_tokens) - 1

        with self.maybe_profile_with_lora(self.lora_config,
                                          num_scheduled_tokens):
            # Trigger compilation for general shape.
            hidden_states = self._dummy_run(self.max_num_tokens,
                                            dummy_kv_caches)
1283
1284
1285
1286
1287
1288
            if get_pp_group().is_last_rank:
                hidden_states = hidden_states[logit_indices]
                logits = self.model.compute_logits(hidden_states, None)
                # TODO(woosuk): Consider the memory usage of the sampler.
            else:
                logits = None
1289
1290
1291
            torch.cuda.synchronize()
            del hidden_states, logits
            self.encoder_cache.clear()
1292
        gc.collect()
    def capture_model(self) -> None:
        """Capture CUDA graphs for every size in ``self.cudagraph_batch_sizes``.

        When ``self.use_cuda_graph`` is false, only a warning is logged and
        nothing is captured. Otherwise, inside a ``graph_capture`` context,
        each size gets ``compilation_config.cudagraph_num_of_warmups`` calls
        to ``self._dummy_run`` followed by one more call. Afterwards the
        elapsed wall time and the drop in free GPU memory over the whole
        capture are logged.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        # mem_get_info() returns (free, total) in bytes; keep the free part.
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        # NOTE(review): relies on cudagraph_batch_sizes being sorted
        # ascending so reversed() yields largest first — confirm at init.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        # Free memory consumed during capture approximates the graphs' cost.
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.

        For every layer in ``kv_cache_config.kv_cache_spec``, a zero-filled
        tensor is allocated on ``self.device``. Its shape comes from
        ``FlashAttentionBackend.get_kv_cache_shape`` and its dtype from the
        layer spec. The resulting ``{layer_name: tensor}`` dict is then passed
        to ``bind_kv_cache`` together with the static forward context and
        ``self.kv_caches``.

        Args:
            kv_cache_config: Configuration for the KV cache, including the KV
                cache size of each layer.

        Raises:
            NotImplementedError: If the config has more than one KV cache
                group, or a layer spec is not a ``FullAttentionSpec``.
        """
        if len(kv_cache_config.groups) > 1:
            raise NotImplementedError(
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        kv_caches: Dict[str, torch.Tensor] = {}

        for layer_name, layer_spec in kv_cache_config.kv_cache_spec.items():
            tensor_config = kv_cache_config.tensors[layer_name]
            # Each layer's byte budget must split into a whole number of pages.
            assert tensor_config.size % layer_spec.page_size_bytes == 0
            num_blocks = tensor_config.size // layer_spec.page_size_bytes
            if isinstance(layer_spec, FullAttentionSpec):
                kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
                    num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
                    layer_spec.head_size)
                dtype = layer_spec.dtype
                kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                    dtype=dtype,
                                                    device=self.device)
            else:
                raise NotImplementedError(
                    "Unsupported KV cache spec for layer "
                    f"{layer_name}: {type(layer_spec).__name__}")

        # NOTE(review): presumably attaches the tensors to the attention
        # layers and fills self.kv_caches — confirm in vllm.v1.utils.
        bind_kv_cache(
            kv_caches,
            self.vllm_config.compilation_config.static_forward_context,
            self.kv_caches)

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """
        Build the KVCacheSpec from the Attention modules registered in the
        static forward context.

        Returns:
            KVCacheSpec: Mapping from layer name to that layer's KV cache
            format. DECODER layers map to a ``FullAttentionSpec``; ENCODER
            and ENCODER_ONLY layers need no KV cache and are left out.

        Raises:
            NotImplementedError: For ENCODER_DECODER attention layers.
            ValueError: For an unrecognized attention type.
        """
        layers = self.vllm_config.compilation_config.static_forward_context
        cache_block_size = self.vllm_config.cache_config.block_size
        spec: KVCacheSpec = {}
        for name, module in layers.items():
            # TODO: Support other attention modules, e.g., sliding window,
            # cross-attention, MLA.
            assert isinstance(module, Attention)
            kind = module.attn_type
            if kind == AttentionType.DECODER:
                spec[name] = FullAttentionSpec(
                    block_size=cache_block_size,
                    num_kv_heads=module.num_kv_heads,
                    head_size=module.head_size,
                    dtype=module.dtype,
                )
                continue
            if kind in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY):
                # Encoder-side attention keeps no KV cache; skip the layer.
                continue
            if kind == AttentionType.ENCODER_DECODER:
                raise NotImplementedError
            raise ValueError(f"Unknown attention type: {kind}")

        return spec