# gpu_model_runner.py
1
import gc
2
import time
3
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
4
5
6
7
8
9

import numpy as np
import torch
import torch.distributed
import torch.nn as nn

10
from vllm.config import CompilationLevel, VllmConfig
11
from vllm.distributed.parallel_state import graph_capture
12
from vllm.forward_context import set_forward_context
13
from vllm.inputs import INPUT_REGISTRY
14
15
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
16
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
17
from vllm.sampling_params import SamplingType
18
19
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
20
21
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
22
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
23
24
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
25
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
26
27
28
29
30
31
32
33
34
35
36

if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput

# Module-level logger for this file, created with vLLM's `init_logger`.
logger = init_logger(__name__)


class GPUModelRunner:

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up configuration, model metadata and persistent buffers.

        Args:
            vllm_config: Engine-wide configuration. Its sub-configs are
                cached as attributes on the runner.
            device: Device that receives the model inputs. The persistent
                GPU buffers and the per-request seeded generators are
                allocated on it.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config
        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model's own dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.is_multimodal_model = model_config.is_multimodal_model
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        # Width of each request's row in the block table.
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY

        # NOTE: Initialized input mapper is only used for processing dummy
        # multimodal data into multimodal kwargs for GPU memory profiling.
        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
        self.mm_input_mapper_profiling.use_cache = False

        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
        self.encoder_cache_size = self.scheduler_config.encoder_cache_size

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
        )

        # CUDA graphs are only used with piecewise compilation and when
        # eager execution is not enforced.
        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.cudagraph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.cudagraph_batch_sizes = list(
            reversed(self.vllm_config.compilation_config.capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

        # Persistent buffers for CUDA graphs.
        self.input_ids = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int32,
                                     device=self.device)
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

        # OPTIMIZATION: Cache the tensors rather than creating them every step.
        # The arange is sliced both as [:num_reqs + 1] (start-location
        # offsets) and as [:n] for per-request token counts, so it must cover
        # the larger of the two bounds.
        self.arange_np = np.arange(max(self.max_num_reqs + 1,
                                       self.max_model_len),
                                   dtype=np.int32)
        # NOTE(woosuk): These tensors are "stateless", i.e., they are literally
        # a faster version of creating a new tensor every time. Thus, we should
        # not make any assumptions about the values in these tensors.
        # Each *_np attribute is a NumPy view sharing memory with its CPU
        # tensor, so NumPy writes land directly in the (optionally pinned)
        # staging buffer.
        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int32,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.input_ids_np = self.input_ids_cpu.numpy()
        self.positions_cpu = torch.zeros(self.max_num_tokens,
                                         dtype=torch.int64,
                                         device="cpu",
                                         pin_memory=self.pin_memory)
        self.positions_np = self.positions_cpu.numpy()
        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
                                            dtype=torch.int32,
                                            device="cpu",
                                            pin_memory=self.pin_memory)
        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                               dtype=torch.int32,
                                               device="cpu",
                                               pin_memory=self.pin_memory)
        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
        self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                             dtype=torch.int32,
                                             device="cpu",
                                             pin_memory=self.pin_memory)
        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Apply one scheduler step to the cached states and persistent batch.

        Drops finished requests and freed encoder outputs. Advances running
        requests (computed-token counts and newly allocated blocks).
        Registers new and resumed requests, placing them first into the
        batch slots vacated by finished or preempted requests. Any slots
        left empty are then condensed.
        """
        # Remove stopped requests from the cached states.
        # Keep the states of the pre-empted requests.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Free the cached encoder outputs; drop a request's entry entirely
        # once its last encoder output is freed.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the requests from the persistent batch. Preempted requests
        # leave the batch but keep their cached state (above), which the
        # resumed-request loop below relies on.
        stopped_req_ids = set().union(
            scheduler_output.preempted_req_ids,
            scheduler_output.finished_req_ids,
        )
        removed_req_indices: List[int] = []
        for req_id in stopped_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Update the states of the running requests.
        for req_data in scheduler_output.scheduled_running_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]
            req_index = self.input_batch.req_id_to_index[req_id]

            # Update the num_computed_tokens.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)

            # Update the block table: append the new blocks after the ones
            # the request already owns.
            if len(req_data.new_block_ids) == 0:
                continue
            start_index = len(req_state.block_ids)
            req_state.block_ids.extend(req_data.new_block_ids)
            self.input_batch.block_table.append_row(req_index, start_index,
                                                    req_data.new_block_ids)

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            req_id = new_req_data.req_id
            sampling_params = new_req_data.sampling_params
            # Explicitly seeded requests get their own generator, seeded
            # once here with the requested seed.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=new_req_data.prompt_token_ids,
                prompt=new_req_data.prompt,
                mm_inputs=new_req_data.mm_inputs,
                mm_positions=new_req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=new_req_data.block_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                output_token_ids=[],
            )
            req_ids_to_add.append(req_id)

        # Update the cached states of the resumed requests. Their blocks are
        # replaced wholesale rather than appended.
        for res_req_data in scheduler_output.scheduled_resumed_reqs:
            req_id = res_req_data.req_id
            req_state = self.requests[req_id]

            req_state.block_ids = res_req_data.block_ids
            req_state.num_computed_tokens = res_req_data.num_computed_tokens
            req_ids_to_add.append(req_id)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first: the list is sorted in
        # descending order so that pop() returns the smallest free index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build the decoder inputs and attention metadata for this step.

        Fills the CPU staging buffers (pinned when available): token ids,
        positions, slot mapping, and query/sequence start locations. Then
        issues non-blocking copies of them to the device, and decides
        whether cascade attention is used for a shared prefix.

        Returns:
            ``(attn_metadata, logits_indices)``:
            - ``attn_metadata`` is the ``FlashAttentionMetadata`` for this
              step.
            - ``logits_indices`` is a device tensor holding, for each
              request, the flat index of its last scheduled token. These are
              the rows whose logits get sampled.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table.commit(num_reqs)

        # Get the number of scheduled tokens for each request.
        # TODO: The Python loop can be slow. Optimize.
        num_scheduled_tokens = []
        max_num_scheduled_tokens = 0
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_scheduled_tokens.append(num_tokens)
            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
                                           num_tokens)
        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
        assert max_num_scheduled_tokens > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange = np.concatenate(
            [self.arange_np[:n] for n in num_scheduled_tokens])

        # Get positions: tokens already computed for the request plus the
        # token's offset within this step. `positions_np` is a view of the
        # positions_cpu staging buffer, so this writes it in place.
        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                           0,
                           torch.from_numpy(token_indices),
                           out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        # NOTE: The block numbers are gathered with torch tensor indexing on
        # the flattened CPU block table rather than with np.take, in the same
        # spirit as the index_select gather for the input ids above.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        # query_start_loc = [0, cumsum(num_scheduled_tokens)...]
        self.query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens,
                  out=self.query_start_loc_np[1:num_reqs + 1])

        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
                    num_scheduled_tokens)
        max_seq_len = seq_lens.max()
        # seq_start_loc = [0, cumsum(seq_lens)...]
        self.seq_start_loc_np[0] = 0
        np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])

        # Copy the tensors to the GPU.
        self.input_ids[:total_num_scheduled_tokens].copy_(
            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
            self.device, non_blocking=True)
        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
            self.device, non_blocking=True).long()

        # Prepare for cascade attention if needed.
        common_prefix_len = (scheduler_output.num_common_prefix_blocks *
                             self.block_size)
        if common_prefix_len == 0:
            # Common case.
            use_cascade = False
        else:
            # NOTE(woosuk): Cascade attention uses two attention kernels: one
            # for the common prefix and the other for the rest. For the first
            # kernel, we concatenate all the query tokens (possibly from
            # different requests) and treat them as if they are from the same
            # request. Then, we use bi-directional attention to process the
            # common prefix in the KV cache. Importantly, this means that the
            # first kernel does not do any masking.

            # Consider the following example:
            # Request 1's input query: [D, E, X]
            # Request 1's kv cache: [A, B, C, D, E, X]
            # Request 1's num_computed_tokens: 3 (i.e., [A, B, C])
            # Request 2's input query: [E, Y]
            # Request 2's kv cache: [A, B, C, D, E, Y]
            # Request 2's num_computed_tokens: 4 (i.e., [A, B, C, D])

            # If we use [A, B, C, D, E] as the common prefix, then the
            # first kernel will compute the bi-directional attention between
            # input query [D, E, X, E, Y] and common prefix [A, B, C, D, E].
            # However, this is wrong because D in Request 1 should not attend to
            # E in the common prefix (i.e., we need masking).
            # To avoid this, [A, B, C, D] should be the common prefix.
            # That is, the common prefix should be capped by the minimum
            # num_computed_tokens among the requests, and plus one to include
            # the first token of the query.

            # In practice, we use [A, B, C] as the common prefix, instead of
            # [A, B, C, D] (i.e., the common prefix is capped by the minimum
            # num_computed_tokens, without plus one).
            # This is because of an implementation detail: We want to always
            # use two kernels for cascade attention. Let's imagine:
            # Request 3's input query: [D]
            # Request 3's kv cache: [A, B, C, D]
            # Request 3's num_computed_tokens: 4 (i.e., [A, B, C, D])
            # If we use [A, B, C, D] as the common prefix for Request 1-3,
            # then Request 3 will be processed only by the first kernel,
            # and the second kernel will get an empty input. While this is not
            # a fundamental problem, our current implementation does not support
            # this case.
            common_prefix_len = min(
                common_prefix_len,
                self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
            # common_prefix_len should be a multiple of the block size.
            common_prefix_len = (common_prefix_len // self.block_size *
                                 self.block_size)
            use_cascade = FlashAttentionBackend.use_cascade_attention(
                common_prefix_len=common_prefix_len,
                query_lens=num_scheduled_tokens,
                num_query_heads=self.num_query_heads,
                num_kv_heads=self.num_kv_heads,
                use_alibi=False,  # FIXME
                use_sliding_window=self.sliding_window is not None,
                num_sms=self.num_sms,
            )

        if use_cascade:
            # TODO: Optimize.
            # The prefix kernel sees all query tokens as a single sequence
            # attending to one shared prefix of length common_prefix_len.
            cu_prefix_query_lens = torch.tensor(
                [0, total_num_scheduled_tokens],
                dtype=torch.int32,
                device=self.device)
            cu_prefix_kv_lens = torch.tensor([0, common_prefix_len],
                                             dtype=torch.int32,
                                             device=self.device)
            # Suffix KV lengths: each request's cumulative sequence offset
            # minus the shared prefix length already subtracted for every
            # preceding request.
            cu_suffix_kv_lens = (
                self.seq_start_loc_np[:num_reqs + 1] -
                self.arange_np[:num_reqs + 1] * common_prefix_len)
            cu_suffix_kv_lens = torch.from_numpy(cu_suffix_kv_lens).to(
                self.device)
        else:
            cu_prefix_query_lens = None
            cu_prefix_kv_lens = None
            cu_suffix_kv_lens = None

        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_start_loc=seq_start_loc,
            block_table=(
                self.input_batch.block_table.get_device_tensor()[:num_reqs]),
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            cu_prefix_kv_lens=cu_prefix_kv_lens,
            cu_suffix_kv_lens=cu_suffix_kv_lens,
        )
        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        logits_indices = query_start_loc[1:] - 1
        return attn_metadata, logits_indices

    def _prepare_sampling(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> SamplingMetadata:
        """Build the sampling metadata for the current persistent batch.

        ``skip_copy`` is True only when batch membership did not change this
        step, i.e. no request finished, was preempted, newly scheduled or
        resumed. It is forwarded to ``input_batch.make_sampling_metadata``,
        presumably so unchanged sampling tensors need not be re-copied.
        """
        skip_copy = not (scheduler_output.finished_req_ids
                         or scheduler_output.preempted_req_ids
                         or scheduler_output.scheduled_new_reqs
                         or scheduler_output.scheduled_resumed_reqs)

        # Create the sampling metadata from every cached request's output
        # tokens so far.
        req_id_output_token_ids: Dict[str, List[int]] = {
            req_id: req.output_token_ids
            for req_id, req in self.requests.items()
        }
        return self.input_batch.make_sampling_metadata(
            req_id_output_token_ids, skip_copy)

    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multimodal encoder on this step's scheduled encoder inputs.

        All ``(req_id, input_id)`` pairs listed in
        ``scheduler_output.scheduled_encoder_inputs`` are batched into a
        single ``get_multimodal_embeddings`` call. Each resulting output is
        stored at ``self.encoder_cache[req_id][input_id]``. Returns early
        when nothing is scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs, remembering where each came from so
        # the outputs can be routed back to (req_id, input_id).
        mm_inputs: List[MultiModalKwargs] = []
        req_input_ids: List[Tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                       device=self.device)

        # Run the encoder.
        # `encoder_outputs` is either of the following:
        # 1. A tensor of shape [num_images, feature_size, hidden_size]
        # in case when feature_size is fixed across all images.
        # 2. A list (length: num_images) of tensors, each of shape
        # [feature_size, hidden_size] in case when the feature size is
        # dynamic depending on input images.
        encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_mm_inputs)

        # Cache the encoder outputs. Iterating either form above yields one
        # output per input, in the same order as `req_input_ids`.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> List[torch.Tensor]:
        """Collect the cached encoder-output slices needed by this step.

        For each request in the persistent batch, each multimodal
        placeholder is checked for overlap with the token window scheduled
        this step, ``[num_computed_tokens, num_computed_tokens +
        num_scheduled_tokens)``. For every overlapping placeholder, the
        matching slice (first dimension) of its cached encoder output is
        appended. Results are returned in batch order, then placeholder
        order.

        The early ``break`` assumes ``mm_positions`` is ordered by
        ``offset`` — TODO confirm with the producer of ``mm_positions``.
        """
        encoder_outputs: List[torch.Tensor] = []
        num_reqs = self.input_batch.num_reqs
        for req_id in self.input_batch.req_ids[:num_reqs]:
            assert req_id is not None
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                # Clip the overlap to the placeholder's own coordinates.
                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        self._update_states(scheduler_output)
564

565
566
567
568
569
570
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_encoder(scheduler_output)
            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
        else:
            encoder_outputs = []
571
572

        # Prepare the decoder inputs.
573
        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
574
575
576
577
578
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if (self.use_cuda_graph
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
579
            num_input_tokens = self.vllm_config.pad_for_cudagraph(
580
581
582
583
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens
584
585
        attn_metadata.num_input_tokens = num_input_tokens

586
587
588
589
590
591
592
593
594
595
596
597
598
599
        if self.is_multimodal_model:
            # NOTE(woosuk): To unify token ids and soft tokens (vision
            # embeddings), we always use embeddings (rather than token ids)
            # as input to the multimodal model, even when the input is text.
            input_ids = self.input_ids[:num_scheduled_tokens]
            if encoder_outputs:
                inputs_embeds = self.model.get_input_embeddings(
                    input_ids, encoder_outputs)
            else:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
            # TODO(woosuk): Avoid the copy. Optimize.
            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
600
        else:
601
602
603
604
605
606
            # For text-only models, we use token ids as input.
            # While it is possible to use embeddings as input just like the
            # multimodal models, it is not desirable for performance since
            # then the embedding layer is not included in the CUDA graph.
            input_ids = self.input_ids[:num_input_tokens]
            inputs_embeds = None
607
608
609

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
610
        with set_forward_context(attn_metadata, self.vllm_config):
611
            hidden_states = self.model(
612
                input_ids=input_ids,
613
                positions=self.positions[:num_input_tokens],
614
                kv_caches=self.kv_caches,
615
                attn_metadata=None,
616
                inputs_embeds=inputs_embeds,
617
            )
618
        hidden_states = hidden_states[:num_scheduled_tokens]
619
620
621
622
623
624
625
626
627
628
        hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self._prepare_sampling(scheduler_output)
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

629
        sampled_token_ids = sampler_output.sampled_token_ids
630
631
632
633
        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        num_reqs = self.input_batch.num_reqs
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
634
            assert req_id is not None
635
636
637
638
639
640
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            assert seq_len <= req_state.num_tokens
            if seq_len == req_state.num_tokens:
                # Append the sampled token to the output token ids.
641
                token_id = sampled_token_ids[i]
642
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
643
                self.input_batch.num_tokens[i] += 1
644
645
646
647
                req_state.output_token_ids.append(token_id)
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
648
                generator = self.input_batch.generators.get(i)
649
                if generator is not None:
650
651
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)
652
653
654
655
656
657
658
659
660

        if sampler_output.logprob_token_ids is None:
            logprob_token_ids = None
        else:
            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
        if sampler_output.logprobs is None:
            logprobs = None
        else:
            logprobs = sampler_output.logprobs.cpu()
661
662
663
664
665
666
667

        # num_reqs entries should be non-None
        assert all(
            req_id is not None for req_id in
            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

668
        model_runner_output = ModelRunnerOutput(
669
            req_ids=req_ids,
670
            req_id_to_index=self.input_batch.req_id_to_index,
671
            sampled_token_ids=sampled_token_ids,
672
673
674
675
676
677
678
679
            logprob_token_ids_cpu=logprob_token_ids,
            logprobs_cpu=logprobs,
        )
        return model_runner_output

    def load_model(self) -> None:
        """Load the model and record the device memory consumed doing so.

        Sets ``self.model`` to the module returned by ``get_model`` and
        ``self.model_memory_usage`` to the ``consumed_memory`` reported by the
        ``DeviceMemoryProfiler`` that wraps the load.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        # The figure is divided by 2**30, i.e. it is expressed in GiB even
        # though the log label reads "GB".
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

686
687
688
689
690
691
692
    @torch.inference_mode()
    def _dummy_run(
        self,
        model: nn.Module,
        num_tokens: int,
        kv_caches: List[torch.Tensor],
    ) -> torch.Tensor:
        """Run one forward pass over the first ``num_tokens`` buffer entries.

        The persistent ``self.input_ids`` / ``self.inputs_embeds`` and
        ``self.positions`` buffers are fed as-is (whatever they currently
        hold), with no attention metadata in the forward context. Used by the
        profiling and CUDA-graph-capture paths.

        Args:
            model: The module to invoke.
            num_tokens: Number of leading buffer entries to feed the model.
            kv_caches: KV cache tensors passed straight through to the model.

        Returns:
            Whatever ``model`` returns, presumably the hidden states.
        """
        # Multimodal models are fed embeddings, text-only models token ids;
        # exactly one of the two inputs is non-None.
        if self.is_multimodal_model:
            input_ids = None
            inputs_embeds = self.inputs_embeds[:num_tokens]
        else:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = None

        with set_forward_context(None, self.vllm_config):
            hidden_states = model(
                input_ids=input_ids,
                positions=self.positions[:num_tokens],
                kv_caches=kv_caches,
                attn_metadata=None,
                inputs_embeds=inputs_embeds,
            )
        return hidden_states

    def profile_run(self) -> None:
        """Exercise the model at its maximum token budget for memory profiling.

        For multimodal models, first builds a dummy batch of multimodal
        inputs, sized by the smaller of the encoder-cache budget and the
        decoder (per-request item limit) budget. It runs the multimodal
        encoder on that batch and stores the outputs in ``self.encoder_cache``.

        It then runs a dummy forward pass with ``self.max_num_tokens`` tokens
        on placeholder KV caches and computes logits. Temporaries and the
        encoder cache are released before returning.

        Raises:
            AssertionError: if the encoder cache budget cannot hold even one
                maximum-size multimodal item, or if the encoder returns a
                different number of outputs than items supplied.
        """
        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than specializing on the value `None`.
        # The `dtype` argument does not matter; `float32` is used as a
        # placeholder because it has wide hardware support.
        # It is important to create a fresh tensor per layer (rather than
        # multiplying a one-element list) so that Dynamo does not treat
        # them as aliases of the same tensor.
        dummy_kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]

        # Profile with multimodal encoder & encoder cache.
        if self.is_multimodal_model:

            # Create dummy batch of multimodal inputs.
            dummy_request_data = self.input_registry.dummy_data_for_profiling(
                model_config=self.model_config,
                seq_len=self.max_num_tokens,
                mm_registry=self.mm_registry,
            )
            dummy_mm_data = dummy_request_data.multi_modal_data

            # NOTE: Currently the model is profiled with a single non-text
            # modality (the one with the most tokens per item), even when
            # it supports multiple.
            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
                self.model_config)

            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
            # the encoder cache budget.
            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
                                       self.encoder_cache_size)
            max_num_mm_items_encoder_budget = encoder_cache_budget // \
                max_tokens_per_mm_item

            # TODO: Allow users to set encoder_cache_budget in case this
            # happens.
            assert max_num_mm_items_encoder_budget > 0, (
                f"Encoder cache budget={encoder_cache_budget} is too small to "
                "support the maximum possible size of multimodal embeddings"
                f"={max_tokens_per_mm_item}.")

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = max(
                self.mm_registry.get_mm_limits_per_prompt(
                    self.model_config).values())

            # NOTE: We do not consider max_num_batched_tokens on purpose
            # because the multimodal embeddings can be generated in advance
            # and chunked prefilled.
            max_num_mm_items_decoder_budget = self.max_num_reqs * \
                max_mm_items_per_req

            max_num_mm_items = min(max_num_mm_items_encoder_budget,
                                   max_num_mm_items_decoder_budget)

            # Dummy data definition in V0 may contain multiple multimodal
            # items (e.g., multiple images) for a single request, therefore
            # here we always replicate the first item max_num_mm_items times,
            # since in V1 they are scheduled to be processed separately.
            if isinstance(dummy_mm_data, MultiModalKwargs):
                # Models with a merged processor: their dummy data is already
                # a batched `MultiModalKwargs`, so take the first
                # `MultiModalKwargsItem` of the desired modality.
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
            else:
                # Models whose dummy data is explicitly defined as a
                # `MultiModalDataDict` must go through the input mapper.
                # TODO (ywang96): deprecate this path once merged processor
                # is supported on all models.
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
                    mm_data=dummy_mm_data,
                    mm_hashes=None,
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)
                dummy_mm_kwargs = mm_kwargs_list[0]

            batched_dummy_mm_inputs = MultiModalKwargs.batch(
                [dummy_mm_kwargs] * max_num_mm_items)
            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
                batched_dummy_mm_inputs, device=self.device)

            # Run multimodal encoder.
            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
                **batched_dummy_mm_inputs)
            assert len(dummy_encoder_outputs) == max_num_mm_items, (
                "Expected dimension 0 of encoder outputs to match the number "
                f"of multimodal data items: {max_num_mm_items}, got "
                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
                "due to the 'get_multimodal_embeddings' method of the model "
                "not implemented correctly.")

            # Cache the dummy encoder outputs so their memory is counted.
            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

        # Trigger compilation for general shape.
        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                        dummy_kv_caches)
        logits = self.model.compute_logits(hidden_states, None)
        logits = logits[:self.max_num_tokens]
        # TODO(woosuk): Consider the memory usage of the sampler.
        torch.cuda.synchronize()
        # Release profiling temporaries before the caller measures memory.
        del hidden_states, logits
        self.encoder_cache.clear()
        gc.collect()
825
826

    def capture_model(self) -> None:
        """Run the dummy passes that capture CUDA graphs for each batch size.

        If ``self.use_cuda_graph`` is false, this only logs a warning and
        returns. Otherwise, inside ``graph_capture``, it iterates over
        ``self.cudagraph_batch_sizes`` from the last entry to the first. For
        each size it runs the configured number of warm-up dummy passes,
        followed by one more dummy pass. Finally it logs the elapsed time and
        the drop in free GPU memory.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        # mem_get_info() returns (free, total); index 0 is free bytes.
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Loop-invariant: read the warm-up count once.
        num_warmups = (
            self.vllm_config.compilation_config.cudagraph_num_of_warmups)

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(num_warmups):
                    self._dummy_run(self.model, num_tokens, self.kv_caches)
                self._dummy_run(self.model, num_tokens, self.kv_caches)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))
853
854
855
856
857
858
859
860
861
862

    def initialize_kv_cache(self, num_blocks: int) -> None:
        assert len(self.kv_caches) == 0
        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        for _ in range(self.num_attn_layers):
            self.kv_caches.append(
                torch.zeros(kv_cache_shape,
                            dtype=self.kv_cache_dtype,
                            device=self.device))