# model_runner.py (45.8 KB)
1
import time
2
import warnings
3
from collections import defaultdict
4
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
5

6
import numpy as np
7
import torch
8
import torch.nn as nn
9

10
from vllm.attention import AttentionMetadata, get_attn_backend
11
12
13
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
14
from vllm.distributed import broadcast_tensor_dict
15
from vllm.distributed.communication_op import graph_capture
16
from vllm.logger import init_logger
17
18
19
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
20
from vllm.model_executor import SamplingMetadata
21
from vllm.model_executor.model_loader import get_model
22
from vllm.multimodal import MULTIMODAL_REGISTRY
23
from vllm.sampling_params import SamplingParams
24
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
25
26
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
27
28
29
30

# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# Slot id written into the slot mapping for tokens that must not be stored in
# the KV cache: CUDA-graph padding, memory-profiling runs (block tables not yet
# allocated) and prompt tokens masked out by the sliding window.
_PAD_SLOT_ID = -1
# NOTE(review): not referenced in the visible part of this file; presumably
# the LoRA rank used for dummy adapters during warm-up — confirm.
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
# NOTE(review): not referenced in the visible part of this file; presumably
# the number of warm-up runs before CUDA-graph capture — confirm.
_NUM_WARMUP_ITERS = 2


class ModelInput(NamedTuple):
    """Batched model inputs produced by ``ModelRunner._prepare_model_input``.

    Token-level tensors are laid out prefill-first, then decode (see
    ``_prepare_model_input``).
    """
    # Flattened token ids / positions for every scheduled token.
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    # Backend-specific attention metadata; None for an empty batch.
    attn_metadata: Optional[AttentionMetadata]
    # Per-sequence lengths (sliding-window capped) and per-sequence query
    # lengths, as host-side lists.
    seq_lens: List[int]
    query_lens: List[int]
    # LoRA index/prompt mapping; None when LoRA is not configured.
    lora_mapping: Optional[LoRAMapping]
    lora_requests: Set[LoRARequest]
    # Multi-modal kwargs concatenated along dim 0 per key.
    multi_modal_kwargs: Dict[str, torch.Tensor]
    # KV-cache slot per token; _PAD_SLOT_ID marks tokens not to be written.
    slot_mapping: torch.Tensor
    num_prefill_tokens: int
    num_decode_tokens: int
    num_prefills: int

    @classmethod
    def empty(cls, device):
        """Return a ``ModelInput`` describing an empty batch on ``device``.

        The index tensors use ``torch.long``, the same dtype
        ``_prepare_model_input`` uses for non-empty batches, so consumers see
        a consistent dtype regardless of batch size.
        """
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_mapping=None,
            lora_requests=set(),
            multi_modal_kwargs={},
            slot_mapping=torch.empty(0, dtype=torch.long, device=device),
            num_prefill_tokens=0,
            num_decode_tokens=0,
            num_prefills=0,
        )


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
    ):
        """Store the engine configs and set up model-independent state.

        The model itself and the LoRA manager are created in ``load_model``;
        the flashinfer workspace buffer is allocated lazily in
        ``_prepare_model_input``.

        Args:
            model_config: Model config (sliding window, capture length, ...).
            parallel_config: Parallelism config used for head counts.
            scheduler_config: Scheduler config (batching limits, chunking).
            device_config: Provides the target ``device``.
            cache_config: Provides the KV-cache ``block_size``.
            load_config: Passed to ``get_model`` when loading weights.
            lora_config: LoRA config, or None when LoRA is disabled.
            kv_cache_dtype: KV-cache dtype string (e.g. "auto", "fp8").
            is_driver_worker: Whether this runner is the driver worker.
            vision_language_config: Enables a multi-modal input processor.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

        # Reuse the sliding window read above rather than querying the model
        # config a second time.
        self.attn_backend = get_attn_backend(
            self.model_config.get_num_attention_heads(self.parallel_config),
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.sliding_window,
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        )

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set if the backend is flashinfer.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model weights and attach optional LoRA / FP8 KV support.

        Records the memory consumed while loading in
        ``self.model_memory_usage`` (bytes; logged in GB). When
        ``self.lora_config`` is set, the model is wrapped by an
        ``LRUCacheWorkerLoRAManager``. On ROCm with an ``"fp8"`` KV cache,
        KV-cache scaling factors are loaded from
        ``model_config.quantization_param_path`` if one is given.

        Raises:
            AssertionError: If LoRA is configured but the model lacks
                ``supported_lora_modules``, ``embedding_modules`` or
                ``embedding_padding_modules``.
            RuntimeError: If FP8 scaling factors are provided but the model
                has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"

            # NOTE(review): self.vocab_size is not assigned in the visible
            # __init__; presumably a property defined later in this class.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            # From here on self.model is the object returned by the manager.
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not interpolate logger-style arguments,
                    # so the model class is formatted into the message here.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Persist ``self.model`` through ``ShardedStateLoader.save_model``.

        Args:
            path: Destination forwarded to ``save_model``.
            pattern: Optional pattern forwarded unchanged as ``pattern=``.
            max_size: Optional limit forwarded unchanged as ``max_size=``.
        """
        # Imported inside the method rather than at module level;
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from vllm.model_executor.model_loader.loader import ShardedStateLoader

        save = ShardedStateLoader.save_model
        save(self.model, path, pattern=pattern, max_size=max_size)

    def get_max_block_per_batch(self) -> int:
        """Return the number of KV-cache blocks needed to hold
        ``max_seq_len_to_capture`` tokens (ceiling division by
        ``block_size``)."""
        max_len = self.max_seq_len_to_capture
        tokens_per_block = self.block_size
        return (max_len + tokens_per_block - 1) // tokens_per_block

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Returns:
            A ``ModelInput``; ``ModelInput.empty(self.device)`` when
            ``seq_group_metadata_list`` is empty.

        Raises:
            RuntimeError: If chunked prefill is enabled and a sequence group
                carries computed (prefix-cached) blocks.
            ValueError: If a sequence group carries multi-modal data but no
                multi-modal input processor was configured.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle slinding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if curr_sliding_window_blocks is not None:
                            block_table = block_table[
                                -curr_sliding_window_blocks:]
                        if self.attn_backend.get_name() == "flashinfer":
                            paged_kv_indices.extend(block_table)
                            paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                   len(block_table))
                            last_page_len = seq_data.get_len(
                            ) % self.block_size
                            if last_page_len == 0:
                                last_page_len = self.block_size
                            paged_kv_last_page_len.append(last_page_len)
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(list(range(context_len, seq_len)))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
                    decode_seq_lens.append(sliding_seq_len)

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

                lora_index_mapping += [lora_id] * query_len
                # One mapping entry per sampled position: every query token
                # when prompt logprobs are requested, otherwise just one.
                lora_prompt_mapping.extend(
                    [lora_id] *
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     is not None else 1))

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data is not None:
                    # Process multi-modal data
                    if self.multi_modal_input_processor is None:
                        raise ValueError(
                            "Multi-modal inputs are only supported by "
                            "vision language models.")

                    mm_kwargs = self.multi_modal_input_processor(mm_data)
                    for k, v in mm_kwargs.items():
                        multi_modal_kwargs_list[k].append(v)

                if _is_block_tables_empty(seq_group_metadata.block_tables):
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue

                # Compute the slot mapping.
                block_table = seq_group_metadata.block_tables[seq_id]

                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                start_idx = 0
                if self.sliding_window is not None:
                    if is_prompt:
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
                            "Prefix caching is currently not supported with "
                            "sliding window attention in V1 block manager")
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    # NOTE(review): the comment above describes
                    # seq_len - sliding_window but query_len is used; the two
                    # agree only when context_len == 0 — confirm for chunked
                    # prefill.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)

        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            num_padding = graph_batch_size - batch_size
            # Padding entries are dummy one-token decodes that never touch
            # the KV cache (slot _PAD_SLOT_ID, empty block table).
            input_tokens.extend([0] * num_padding)
            input_positions.extend([0] * num_padding)
            slot_mapping.extend([_PAD_SLOT_ID] * num_padding)
            seq_lens.extend([1] * num_padding)
            block_tables.extend([] for _ in range(num_padding))
            lora_index_mapping.extend([0] * num_padding)
            batch_size = graph_batch_size
            num_decode_tokens = batch_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables_tensor = torch.tensor(input_block_tables,
                                               device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables_tensor = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=self.device)

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        # Exclusive prefix sums: *_start_loc[k] is the offset of entry k.
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)

        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                  dtype=torch.int,
                                                  device=self.device)
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                use_cuda_graph=False,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables_tensor,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                # Pages are KV-cache blocks: paged_kv_indices holds block
                # numbers and paged_kv_last_page_len is computed modulo
                # self.block_size, so the page size must match it.
                page_size=self.block_size,
                seq_start_loc=seq_start_loc,
                data_type=kv_cache_dtype)
        else:
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables_tensor,
                use_cuda_graph=use_captured_graph,
            )

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[AttentionMetadata],
               SamplingMetadata, Set[LoRARequest], Optional[LoRAMapping],
               Dict[str, torch.Tensor]]:
        """Build the model inputs and share them across workers.

        On the driver worker the inputs are computed from
        ``seq_group_metadata_list`` and broadcast from rank 0. On every other
        worker ``seq_group_metadata_list`` is ignored, and the inputs are
        rebuilt from the broadcast dictionary.

        Args:
            seq_group_metadata_list: Sequence groups to run. Must not be
                ``None`` on the driver worker.

        Returns:
            ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_kwargs)``.
            - ``lora_mapping`` is ``None`` when LoRA is disabled.
            - On non-driver workers, ``sampling_metadata`` only carries
              ``selected_token_indices``.
            - On non-driver workers, ``attn_metadata`` is ``None`` if no
              attention fields were broadcast.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                attn_metadata,
                seq_lens,
                query_lens,
                lora_mapping,
                lora_requests,
                multi_modal_kwargs,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
                num_prefills,
            ) = self._prepare_model_input(seq_group_metadata_list)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, seq_lens, query_lens, self.device,
                self.pin_memory)

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_kwargs": multi_modal_kwargs,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
            }
            if attn_metadata:
                metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            # Every key not popped above is forwarded to the attention
            # backend. That covers slot_mapping, num_prefill_tokens,
            # num_decode_tokens, num_prefills and the asdict_zerocopy()
            # fields.
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                attn_metadata = None
            # Non-driver workers never sample, so they only need the indices
            # used when computing logits.
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_kwargs)

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and sample on the driver worker.

        Args:
            seq_group_metadata_list: Sequence groups to run. Only the driver
                worker reads it; other workers receive their inputs by
                broadcast in ``prepare_input_tensors``.
            kv_caches: Per-layer KV cache tensors handed to the model.

        Returns:
            The sampler output on the driver worker. ``None`` on every other
            worker, which only computes logits.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_kwargs
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            # graph_runners is keyed by captured batch size (see
            # capture_model). This assumes the decode batch was already
            # padded to one of those sizes.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one worst-case dummy batch to profile peak memory usage.

        The batch holds up to ``max_num_seqs`` prompt-only sequence groups.
        Their lengths add up to ``max_num_batched_tokens``.

        - With LoRA enabled, each sequence is assigned one of ``max_loras``
          dummy adapters.
        - With a vision-language config, dummy multi-modal data is attached.
          The number of sequences is then capped at
          ``max_num_batched_tokens // image_feature_size``, but never below 1.

        The batch runs through ``execute_model`` with placeholder (``None``)
        KV caches.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # max_loras is the maximum number of different requests that can
        # carry unique LoRAs, and therefore the largest LoRA memory
        # consumption. Register that many dummy adapters and hand them out
        # round-robin to the dummy sequences.
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding. It has to
        # be accounted for when the vLLM block manager computes the number of
        # GPU blocks.
        # To exercise the worst case for GPU memory consumption, the number of
        # seqs (batch_size) is chosen to maximize the number of images
        # processed.
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens // vlm_config.image_feature_size)
            if max_num_seqs < 1:
                # Fewer batched tokens than one image's features. Without
                # this clamp the loop below would build no sequences and the
                # profiling batch would be empty.
                logger.warning(
                    "Computed max_num_seqs (min(%d, %d // %d)) to be less "
                    "than 1. Setting it to the minimum value of 1.",
                    max_num_seqs_orig, max_num_batched_tokens,
                    vlm_config.image_feature_size)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Split max_num_batched_tokens as evenly as possible. The first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra
            # token.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            if vlm_config is None:
                seq_data = SequenceData([0] * seq_len)
                dummy_multi_modal_data = None
            else:
                seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
                    .dummy_data_for_profiling(seq_len, model_config, vlm_config)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

    def remove_all_loras(self):
        """Delegate to the LoRA manager's ``remove_all_loras()``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` in the LoRA manager using ``lora_mapping``.

        Args:
            lora_requests: LoRA requests to make active.
            lora_mapping: Mapping forwarded unchanged to the manager.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Hand ``lora_request`` to the LoRA manager's ``add_lora``.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to remove the adapter with id ``lora_id``.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the LoRA manager's ``list_loras()`` result.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager is set).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured per batch size in
        ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed the padded
        ``max_num_seqs``. Each runner is stored in
        ``self.graph_runners[batch_size]``.

        Args:
            kv_caches: Per-layer KV cache tensors used during capture.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes:
        # each capture uses a leading slice of these max-size tensors.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        # Dummy tokens point at the padding slot id. Presumably the attention
        # backend skips KV writes for it; TODO confirm per backend.
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata (decode-only batch).
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                    stream=graph_capture_context.stream,
                )
                # Reuse this graph's memory pool for the next capture so all
                # captured graphs share one pool.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        config = self.model_config
        return config.get_vocab_size()


class CUDAGraphRunner:
    """Captures one forward pass of ``model`` in a CUDA graph and replays it.

    ``capture`` records the graph against a fixed set of input tensors.
    ``forward`` copies new values into those same tensors, replays the graph,
    and returns the output tensor recorded during capture.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Tensors recorded at capture time. forward() copies fresh inputs
        # into them before every replay. "kv_caches" holds the list of
        # per-layer caches, not a single tensor.
        self.input_buffers: Dict[str, Union[torch.Tensor,
                                            List[torch.Tensor]]] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph. Only valid after ``capture`` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> None:
        """Warm up the model, then record one forward pass into a CUDA graph.

        Args:
            input_ids: Token id tensor that becomes a persistent input buffer.
            positions: Position tensor that becomes a persistent input buffer.
            kv_caches: Per-layer KV caches, passed to the model as-is.
            attn_metadata: Attention metadata. Its ``slot_mapping`` and its
                ``decode_metadata.seq_lens_tensor`` and
                ``decode_metadata.block_tables`` also become input buffers.
            memory_pool: Graph memory pool to capture into (may be ``None``).
            stream: Stream on which the graph is captured.
            **kwargs: Extra model inputs used for warmup and capture.

        Must be called at most once per runner.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy the inputs into the captured buffers and replay the graph.

        NOTE(review): ``**kwargs`` is accepted for call compatibility but is
        not copied anywhere. The replayed graph keeps reading whatever extra
        inputs were passed at capture time.

        Returns:
            The output tensor recorded at capture time. It is overwritten by
            each replay.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    - Sizes up to 2 are returned unchanged.
    - Sizes 3 and 4 map to 4.
    - Larger sizes are rounded up to the next multiple of
      ``_BATCH_SIZE_ALIGNMENT``.

    The result is not capped at ``max(_BATCH_SIZES_TO_CAPTURE)``.
    """
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
        # Ceiling division, then scale back up to a multiple of the alignment.
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False