# model_runner.py (45.6 KB)
1
import time
2
import warnings
3
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
4

5
import numpy as np
6
import torch
7
import torch.nn as nn
8

9
from vllm.attention import AttentionMetadata, get_attn_backend
10
11
12
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
13
from vllm.distributed import broadcast_tensor_dict
14
from vllm.distributed.communication_op import graph_capture
15
from vllm.logger import init_logger
16
17
18
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
19
from vllm.model_executor import SamplingMetadata
20
from vllm.model_executor.model_loader import get_model
21
from vllm.sampling_params import SamplingParams
22
23
from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData,
                           SequenceGroupMetadata)
24
25
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
26
27
28
29

logger = init_logger(__name__)

# Sentinel slot id placed in the slot mapping for tokens whose KV entries are
# not written: CUDA-graph padding, the masked sliding-window prefix of a
# prompt, and memory-profiling runs without block tables.
_PAD_SLOT_ID = -1
# Presumably the LoRA rank used for warm-up / profiling adapters (not used in
# this chunk) — TODO confirm against the caller.
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# CUDA graphs are captured for decode batch sizes (one token per sequence)
# 1, 2, 4 followed by every multiple of _BATCH_SIZE_ALIGNMENT up to 256:
# 8, 16, 24, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, _BATCH_SIZE_ALIGNMENT * 32 + 1,
          _BATCH_SIZE_ALIGNMENT))


class ModelInput(NamedTuple):
    """Batched model inputs produced by ``ModelRunner._prepare_model_input``.

    Token-level tensors are laid out with all prefill tokens first, followed
    by the decode tokens.
    """
    # Flattened token ids and their positions for the whole batch.
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    # Backend-specific attention metadata; ``None`` for an empty batch.
    attn_metadata: Optional[AttentionMetadata]
    # Per-sequence total length (capped to the sliding window for decodes)
    # and number of new tokens processed this step.
    seq_lens: List[int]
    query_lens: List[int]
    # LoRA index / prompt mapping; ``None`` when LoRA is not configured.
    lora_mapping: Optional[LoRAMapping]
    lora_requests: Set[LoRARequest]
    # Concatenated multi-modal inputs, or ``None`` if the batch has none.
    multi_modal_input: Optional[torch.Tensor]
    # KV-cache slot for each token; ``_PAD_SLOT_ID`` marks tokens not stored.
    slot_mapping: torch.Tensor
    num_prefill_tokens: int
    num_decode_tokens: int
    num_prefills: int

    @classmethod
    def empty(cls, device):
        """Return the input for an empty batch, placed on ``device``.

        The index tensors (tokens, positions, slot mapping) use
        ``torch.long``, the same dtype ``_prepare_model_input`` uses for
        non-empty batches, so consumers see a consistent dtype either way.
        (``torch.empty`` without ``dtype`` would produce a float tensor.)
        """
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_mapping=None,
            lora_requests=set(),
            multi_modal_input=None,
            slot_mapping=torch.empty(0, dtype=torch.long, device=device),
            num_prefill_tokens=0,
            num_decode_tokens=0,
            num_prefills=0,
        )


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Record the engine configuration and set up per-runner state.

        The model weights are not loaded here; ``load_model`` does that.

        Args:
            model_config: Model description (dtype, sliding window, ...).
            parallel_config: Parallelism layout, used to derive per-worker
                head counts.
            scheduler_config: Scheduler limits and feature flags.
            device_config: Target device for the runner's tensors.
            cache_config: KV-cache settings (block size).
            load_config: Weight-loading options passed to ``get_model``.
            lora_config: LoRA settings, or ``None`` when LoRA is disabled.
            kv_cache_dtype: KV-cache dtype string forwarded to the attention
                backend selector.
            is_driver_worker: Whether this worker prepares and broadcasts
                inputs.
            vision_language_config: Extra config for vision-language models.
        """
        # Configuration objects, kept verbatim.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        self.device = device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = model_config.max_seq_len_to_capture

        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        # Set during graph capture.
        self.graph_memory_pool: Optional[Tuple[int, int]] = None

        # With CUDA graphs the input block tables must be padded out to
        # max_seq_len_to_capture. Building such a padded table in Python
        # every step is expensive, so a numpy buffer of shape
        # (largest captured batch size, max blocks per sequence) is kept
        # here and only the live entries are overwritten each iteration.
        max_graph_batch = max(_BATCH_SIZES_TO_CAPTURE)
        self.graph_block_tables = np.zeros(
            (max_graph_batch, self.get_max_block_per_batch()),
            dtype=np.int32)

        num_heads = model_config.get_num_attention_heads(parallel_config)
        head_size = model_config.get_head_size()
        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.attn_backend = get_attn_backend(
            num_heads,
            head_size,
            num_kv_heads,
            model_config.get_sliding_window(),
            model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Lazily initialized attributes (annotation only, no value yet).
        self.model: nn.Module  # Assigned by load_model.
        # Allocated on first use when the attention backend is flashinfer.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Assigned by load_model when LoRA is enabled.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model weights and attach optional LoRA / KV-scale state.

        The model is built with ``get_model`` inside a ``CudaMemoryProfiler``
        and the memory it consumed is logged. When LoRA is configured, the
        model is wrapped by an ``LRUCacheWorkerLoRAManager``. On ROCm with an
        fp8 KV cache, scaling factors are loaded from
        ``model_config.quantization_param_path`` if one is given.

        Raises:
            AssertionError: LoRA is configured but the model lacks the
                required LoRA attributes.
            RuntimeError: An fp8 KV cache with a scaling-factor path is
                requested on ROCm, but the model cannot load the scales.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            # NOTE: self.vocab_size is not assigned in __init__; presumably a
            # property defined further down the class — verify.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not %-format their arguments (unlike
                    # logger calls), so build the message explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write the loaded model's state to ``path`` in sharded form.

        Delegates to ``ShardedStateLoader.save_model``, forwarding
        ``pattern`` and ``max_size`` unchanged.
        """
        # Imported inside the method rather than at module level.
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(self.model, path,
                                      pattern=pattern, max_size=max_size)

    def get_max_block_per_batch(self) -> int:
        """Return how many KV blocks cover ``max_seq_len_to_capture`` tokens.

        This is a ceiling division of the capture length by the block size.
        """
        capture_len = self.max_seq_len_to_capture
        per_block = self.block_size
        rounded_up = capture_len + per_block - 1
        return rounded_up // per_block

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Args:
            seq_group_metadata_list: Scheduled sequence groups, prefills
                first.

        Returns:
            A ``ModelInput`` holding device tensors for tokens, positions and
            the slot mapping, plus backend-specific attention metadata. An
            empty list yields ``ModelInput.empty(self.device)``.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_input_list: List[torch.Tensor] = []
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        # The backend name is loop-invariant; look it up once.
        attn_backend_name = self.attn_backend.get_name()

        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

                    if attn_backend_name == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if curr_sliding_window_blocks is not None:
                            block_table = block_table[
                                -curr_sliding_window_blocks:]
                        if attn_backend_name == "flashinfer":
                            paged_kv_indices.extend(block_table)
                            paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                   len(block_table))
                            last_page_len = seq_data.get_len(
                            ) % self.block_size
                            if last_page_len == 0:
                                last_page_len = self.block_size
                            paged_kv_last_page_len.append(last_page_len)
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(list(range(context_len, seq_len)))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
                    decode_seq_lens.append(sliding_seq_len)

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

                lora_index_mapping += [lora_id] * query_len
                lora_prompt_mapping.extend(
                    [lora_id] *
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     else 1))

                if seq_group_metadata.multi_modal_data:
                    multi_modal_input_list.append(
                        seq_group_metadata.multi_modal_data.data)

                if _is_block_tables_empty(seq_group_metadata.block_tables):
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue

                # Compute the slot mapping.
                block_table = seq_group_metadata.block_tables[seq_id]

                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                # NOTE(review): the code below uses query_len, not seq_len;
                # the two coincide for a prompt with context_len == 0.
                start_idx = 0
                if self.sliding_window is not None:
                    if is_prompt:
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
                            "Prefix caching is currently not supported with "
                            "sliding window attention in V1 block manager")
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)

        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

        if multi_modal_input_list:
            assert self.vision_language_config, (
                "Multi-modal inputs are only supported by "
                "vision language models.")
            multi_modal_input = torch.cat(multi_modal_input_list,
                                          dim=0).to(self.device)
        else:
            multi_modal_input = None

        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: entry k is the start offset of sequence k.
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)

        if attn_backend_name == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                  dtype=torch.int,
                                                  device=self.device)
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                use_cuda_graph=False,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                # FlashInfer pages are the KV-cache blocks: paged_kv_indices
                # holds block numbers and paged_kv_last_page_len is computed
                # modulo self.block_size, so the page size must be the block
                # size (previously hard-coded to 16).
                page_size=self.block_size,
                seq_start_loc=seq_start_loc,
                data_type=kv_cache_dtype)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_input=multi_modal_input,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[AttentionMetadata],
               SamplingMetadata, Set[LoRARequest], Optional[LoRAMapping],
               Optional[torch.Tensor]]:
        """Build the model inputs and share them with the other workers.

        The driver worker computes the inputs from ``seq_group_metadata_list``
        and broadcasts them from rank 0 via ``broadcast_tensor_dict``; every
        other worker receives that broadcast and rebuilds the attention and
        sampling metadata from it.

        Args:
            seq_group_metadata_list: Scheduled sequence groups. Required (and
                asserted non-None) on the driver worker; not read elsewhere.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_input)``. ``attn_metadata`` may be ``None``;
            ``lora_mapping`` is ``None`` when LoRA is disabled;
            ``multi_modal_input`` may be ``None``. On non-driver workers
            ``sampling_metadata`` only carries ``selected_token_indices``.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                attn_metadata,
                seq_lens,
                query_lens,
                lora_mapping,
                lora_requests,
                multi_modal_input,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
                num_prefills,
            ) = self._prepare_model_input(seq_group_metadata_list)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, seq_lens, query_lens, self.device,
                self.pin_memory)

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_input": multi_modal_input,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
            }
            # The attention metadata fields are flattened into the same dict
            # so that a single broadcast carries everything.
            if attn_metadata:
                metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            # Pop the non-attention entries; whatever remains is handed to
            # the attention backend as metadata fields.
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_input = metadata_dict.pop("multi_modal_input")
            # NOTE(review): the driver always sends num_prefill_tokens,
            # num_decode_tokens, slot_mapping and num_prefills, so
            # metadata_dict is never empty here and the None branch looks
            # unreachable; if the driver had no attn_metadata, make_metadata
            # would receive only those four fields — confirm.
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                attn_metadata = None
            # Non-driver workers do not sample, so only the indices needed
            # for computing logits are reconstructed.
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_input)

699
700
701
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass over the scheduled batch.

        Logits are computed on every worker; sampling happens only on the
        driver worker.

        Args:
            seq_group_metadata_list: Scheduled sequence groups, forwarded to
                ``prepare_input_tensors`` (only read on the driver worker).
            kv_caches: Per-layer KV cache tensors passed to the model.

        Returns:
            The sampler output on the driver worker, ``None`` elsewhere.
        """
        model_inputs = self.prepare_input_tensors(seq_group_metadata_list)
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_input) = model_inputs

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # CUDA graphs are only captured for decode-only batches, so a graph
        # is replayed only when there is no prefill part and the decode
        # metadata opts in. Both properties are read before deciding.
        # NOTE(review): prepare_input_tensors may return attn_metadata=None,
        # which would raise AttributeError here — confirm that cannot occur.
        prefill_part = attn_metadata.prefill_metadata
        decode_part = attn_metadata.decode_metadata
        replay_graph = prefill_part is None and decode_part.use_cuda_graph
        # Graph runners are keyed by captured batch size; the token tensor's
        # first dimension (presumably already padded) selects one.
        model_executable = (self.graph_runners[input_tokens.shape[0]]
                            if replay_graph else self.model)

        forward_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
        }
        if self.vision_language_config:
            forward_kwargs["image_input"] = multi_modal_input
        hidden_states = model_executable(**forward_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Sampling is the driver worker's job only.
        if self.is_driver_worker:
            return self.model.sample(
                logits=logits,
                sampling_metadata=sampling_metadata,
            )
        return None

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a worst-case dummy batch so peak GPU memory can be measured.

        Builds prompt-only dummy sequence groups whose lengths add up to
        ``scheduler_config.max_num_batched_tokens`` (at most
        ``scheduler_config.max_num_seqs`` of them). It attaches dummy LoRAs
        when LoRA is enabled and fake image inputs for vision-language
        models. The batch runs through :meth:`execute_model` with ``None``
        placeholders for the KV caches, and CUDA is synchronized before
        returning.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # This represents the maximum number of different requests that will
        # have unique LoRAs, and therefore the max amount of memory
        # consumption. Create one dummy LoRA per LoRA slot (max_loras), backed
        # by the dummy LoRA cache, and assign them round-robin to the dummy
        # sequences.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        if self.vision_language_config:
            image_feature_size = self.vision_language_config.image_feature_size
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // image_feature_size)
            if max_num_seqs < 1:
                # A single image does not fit into max_num_batched_tokens.
                # Without this clamp range(0) would build no dummy sequence
                # and the vision path would not be profiled at all. The one
                # fake prompt is then image_feature_size tokens long.
                logger.warning(
                    "image_feature_size (%d) exceeds max_num_batched_tokens "
                    "(%d); profiling with a single sequence instead.",
                    image_feature_size, max_num_batched_tokens)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Split the token budget as evenly as possible: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data, fake_multi_modal_input = _prepare_fake_inputs(
                seq_len, self.vision_language_config)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=fake_multi_modal_input,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs; every layer gets a None
        # placeholder instead of a KV cache tensor.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

813
    def remove_all_loras(self):
        """Drop every LoRA adapter held by the LoRA worker manager.

        Raises:
            RuntimeError: If LoRA is not enabled (the manager is falsy).
        """
        manager = self.lora_manager
        if manager:
            manager.remove_all_loras()
            return
        raise RuntimeError("LoRA is not enabled.")

818
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` with ``lora_mapping`` on the manager.

        Raises:
            RuntimeError: If LoRA is not enabled (the manager is falsy).
        """
        manager = self.lora_manager
        if manager:
            manager.set_active_loras(lora_requests, lora_mapping)
            return
        raise RuntimeError("LoRA is not enabled.")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA worker manager.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: If LoRA is not enabled (the manager is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA with integer id ``lora_id`` from the manager.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: If LoRA is not enabled (the manager is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids currently known to the LoRA worker manager.

        Raises:
            RuntimeError: If LoRA is not enabled (the manager is falsy).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

839
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured per batch size in
        ``_BATCH_SIZES_TO_CAPTURE`` up to the padded ``max_num_seqs`` and
        stored in ``self.graph_runners``; all graphs share one memory pool.

        Args:
            kv_caches: Per-layer KV cache tensors the graphs are captured
                against.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        # They are created directly on the current CUDA device instead of
        # being allocated on the host and copied over.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device="cuda")
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device="cuda")
        slot_mapping = torch.full((max_batch_size, ),
                                  _PAD_SLOT_ID,
                                  dtype=torch.long,
                                  device="cuda")
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32, device="cuda")
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata for a decode-only batch of
                # `batch_size` sequences, one token each.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )

                if self.lora_config:
                    # Capture with an all-zero LoRA mapping and an empty set
                    # of active LoRA requests.
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                    stream=graph_capture_context.stream,
                )
                # Reuse this graph's pool for the next (smaller) capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

926
927
928
929
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        model_config = self.model_config
        return model_config.get_vocab_size()

930
931
932
933
934
935
936
937

class CUDAGraphRunner:
    """Captures one CUDA graph of ``model``'s forward pass and replays it.

    The tensors handed to :meth:`capture` become the graph's static input
    buffers. :meth:`forward` copies new inputs into those buffers, replays
    the graph and returns the static output tensor.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}
        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph; asserts that :meth:`capture` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> None:
        """Warm the model up once, then record its forward pass as a graph.

        Args:
            input_ids, positions, kv_caches, attn_metadata: Model inputs;
                the tensors are kept as the graph's input buffers.
            memory_pool: Pool handle passed to ``torch.cuda.graph``.
            stream: Stream the capture runs on.
            **kwargs: Extra model keyword arguments used during capture.
        """
        assert self._graph is None
        model_args = (input_ids, positions, kv_caches, attn_metadata)

        # Run the model once without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            hidden_states = self.model(*model_args, **kwargs)
        torch.cuda.synchronize()

        # Remember which tensors the graph reads from and writes to.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the captured buffers and replay the graph.

        ``**kwargs`` are accepted but not used. The returned tensor is the
        graph's static output buffer, which is overwritten by the next replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        buffers = self.input_buffers
        decode_src = attn_metadata.decode_metadata
        buffers["input_ids"].copy_(input_ids, non_blocking=True)
        buffers["positions"].copy_(positions, non_blocking=True)
        buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                      non_blocking=True)
        buffers["seq_lens_tensor"].copy_(decode_src.seq_lens_tensor,
                                         non_blocking=True)
        buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

        # Run the graph.
        self.graph.replay()
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1021

1022
def _get_graph_batch_size(batch_size: int) -> int:
1023
1024
1025
1026
1027
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    """
1028
1029
1030
1031
1032
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
1033
1034
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052


def _prepare_fake_inputs(
        seq_len: int, vision_language_config: Optional[VisionLanguageConfig]):
    """Prepare fake inputs for profile run.

    Returns a ``(SequenceData, fake_image_input)`` pair. Without a
    vision-language config the prompt is ``seq_len`` zero tokens and no
    image is produced. With one, the prompt is ``image_feature_size`` copies
    of ``image_token_id`` followed by zeros up to ``seq_len``; no zeros are
    added when ``seq_len`` is smaller, so the prompt can exceed ``seq_len``.
    The fake image is a float16 zero tensor of ``image_input_shape``.
    """
    if not vision_language_config:
        return SequenceData([0] * seq_len), None

    num_image_tokens = vision_language_config.image_feature_size
    prompt_tokens = ([vision_language_config.image_token_id] *
                     num_image_tokens + [0] * (seq_len - num_image_tokens))
    fake_image_input = MultiModalData(
        type=MultiModalData.Type.IMAGE,
        data=torch.zeros(vision_language_config.image_input_shape,
                         dtype=torch.float16))
    return SequenceData(prompt_tokens), fake_image_input


def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False