# model_runner.py
import gc
2
import time
3
import warnings
4
from collections import defaultdict
5
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
6

7
import numpy as np
8
import torch
9
import torch.nn as nn
10

11
from vllm.attention import AttentionMetadata, get_attn_backend
12
13
14
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
15
from vllm.distributed import broadcast_tensor_dict
16
from vllm.distributed.parallel_state import graph_capture
17
from vllm.logger import init_logger
18
19
20
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
21
from vllm.model_executor import SamplingMetadata
22
from vllm.model_executor.model_loader import get_model
23
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
24
from vllm.multimodal import MULTIMODAL_REGISTRY
25
from vllm.sampling_params import SamplingParams
26
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
27
28
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
29
30
31
32

logger = init_logger(__name__)

# Slot id written into the slot mapping for tokens whose KV entries must not
# be stored: dummy slots during memory profiling, masked sliding-window
# positions, and CUDA-graph batch padding.
_PAD_SLOT_ID = -1
# Presumably the LoRA rank used for warm-up adapters; its use lies outside
# this part of the file — confirm against the warm-up code.
LORA_WARMUP_RANK = 8

# Every captured batch size above 4 is a multiple of this alignment.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, then every multiple of
# _BATCH_SIZE_ALIGNMENT up to 32 * _BATCH_SIZE_ALIGNMENT (8, 16, ..., 256).
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + list(
    range(_BATCH_SIZE_ALIGNMENT, 32 * _BATCH_SIZE_ALIGNMENT + 1,
          _BATCH_SIZE_ALIGNMENT))
# Number of warm-up iterations; consumed outside this part of the file.
_NUM_WARMUP_ITERS = 2


class ModelInput(NamedTuple):
    """Batched model inputs built by ``ModelRunner._prepare_model_input``.

    Token-level tensors are laid out with all prefill tokens first, followed
    by all decode tokens. Field order matters: callers unpack this tuple
    positionally.
    """
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[AttentionMetadata]
    seq_lens: List[int]
    query_lens: List[int]
    lora_mapping: Optional[LoRAMapping]
    lora_requests: Set[LoRARequest]
    multi_modal_kwargs: Dict[str, torch.Tensor]
    slot_mapping: torch.Tensor
    num_prefill_tokens: int
    num_decode_tokens: int
    num_prefills: int

    @classmethod
    def empty(cls, device) -> "ModelInput":
        """Return a ``ModelInput`` describing a batch with no sequences.

        The token, position and slot-mapping tensors are created as
        ``torch.long``. That is the dtype used for the same fields when a
        non-empty batch is prepared, so consumers see one dtype regardless
        of batch size.
        """
        return cls(
            input_tokens=torch.empty(0, dtype=torch.long, device=device),
            input_positions=torch.empty(0, dtype=torch.long, device=device),
            attn_metadata=None,
            seq_lens=[],
            query_lens=[],
            lora_mapping=None,
            lora_requests=set(),
            multi_modal_kwargs={},
            slot_mapping=torch.empty(0, dtype=torch.long, device=device),
            num_prefill_tokens=0,
            num_decode_tokens=0,
            num_prefills=0,
        )


class ModelRunner:

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Record configuration and build state that does not need the model.

        The model, the LoRA manager and the flashinfer workspace buffer are
        not created here. ``load_model`` assigns the first two, and the
        workspace buffer is allocated on first use in ``_prepare_model_input``.
        """
        # Configuration objects, stored unchanged.
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config
        self.return_hidden_states = return_hidden_states

        self.device = device_config.device
        self.pin_memory = is_pin_memory_available()

        # block_size and max_seq_len_to_capture must be set before
        # get_max_block_per_batch() is called below.
        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = model_config.max_seq_len_to_capture
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        # Filled in during CUDA graph capture.
        self.graph_memory_pool: Optional[Tuple[int, int]] = None

        # With CUDA graphs the input block tables must be padded to
        # max_seq_len_to_capture. Building that padded table in Python on
        # every step is expensive, so one numpy buffer of shape
        # (largest captured batch size, ceil(max_seq_len_to_capture /
        # block_size)) is kept and only the live rows are overwritten.
        largest_captured_batch = max(_BATCH_SIZES_TO_CAPTURE)
        self.graph_block_tables = np.zeros(
            (largest_captured_batch, self.get_max_block_per_batch()),
            dtype=np.int32)

        num_attn_heads = model_config.get_num_attention_heads(parallel_config)
        if num_attn_heads:
            self.attn_backend = get_attn_backend(
                num_attn_heads,
                model_config.get_head_size(),
                model_config.get_num_kv_heads(parallel_config),
                model_config.get_sliding_window(),
                model_config.dtype,
                self.kv_cache_dtype,
                self.block_size,
            )
        else:
            # Zero attention heads: no attention backend is selected.
            self.attn_backend = None

        # Multi-modal input processing only exists for vision-language
        # configurations.
        if vision_language_config is None:
            self.multi_modal_input_processor = None
        else:
            self.multi_modal_input_processor = (
                MULTIMODAL_REGISTRY.create_input_processor(
                    model_config,
                    vision_language_config,
                ))

        # Declared here, assigned by load_model().
        self.model: nn.Module
        # Declared here, allocated lazily when the flashinfer backend is used.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Replaced by load_model() when a LoRA config is present.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model weights and attach optional LoRA / KV-scale support.

        The memory consumed while loading is stored in
        ``self.model_memory_usage`` (bytes) and logged in GiB. When a LoRA
        config is present, an ``LRUCacheWorkerLoRAManager`` is created and
        the model is replaced by the manager-wrapped model. For an FP8 KV
        cache on ROCm (``is_hip()``), KV-cache scaling factors are loaded
        from ``quantization_param_path`` when one is configured.

        Raises:
            RuntimeError: FP8 KV-cache scaling factors were provided, but the
                model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply printf-style formatting to
                    # extra args, so the class name is interpolated here.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the loaded model's weights under ``path`` as sharded state.

        Delegates to ``ShardedStateLoader.save_model``, forwarding
        ``pattern`` and ``max_size`` unchanged.
        """
        # Imported inside the method, as in the original; the reason (import
        # cycle or load cost) is not visible here.
        from vllm.model_executor.model_loader.loader import (
            ShardedStateLoader as _ShardedStateLoader)
        model = self.model
        _ShardedStateLoader.save_model(model,
                                       path,
                                       pattern=pattern,
                                       max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize the loaded model with Tensorizer.

        Thin wrapper over ``TensorizerLoader.save_model``; all serialization
        settings come from ``tensorizer_config``.
        """
        from vllm.model_executor.model_loader.loader import (
            TensorizerLoader as _TensorizerLoader)
        model = self.model
        _TensorizerLoader.save_model(model,
                                     tensorizer_config=tensorizer_config)

    def get_max_block_per_batch(self) -> int:
        """Return the number of KV-cache blocks spanning the capture length.

        Computed as ceil(max_seq_len_to_capture / block_size) using integer
        arithmetic.
        """
        size = self.block_size
        padded_len = self.max_seq_len_to_capture + (size - 1)
        return padded_len // size

    def _prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInput:
        """Prepare the model input based on a given sequence group.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Returns ``ModelInput.empty(self.device)`` for an empty list.

        Raises:
            RuntimeError: chunked prefill is enabled and a sequence group
                carries computed (prefix-cached) blocks.
            ValueError: multi-modal data is supplied but no multi-modal
                input processor was configured.
        """
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()

        seq_lens: List[int] = []
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
        context_lens: List[int] = []
        query_lens: List[int] = []
        block_tables: List[List[int]] = []
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0

        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request's page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

        if len(seq_group_metadata_list) == 0:
            return ModelInput.empty(self.device)

        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
            is_prompt = seq_group_metadata.is_prompt

            for seq_id in seq_ids:
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

                seq_data = seq_group_metadata.seq_data[seq_id]
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]

                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle sliding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
                        if curr_sliding_window_blocks is not None:
                            block_table = block_table[
                                -curr_sliding_window_blocks:]
                        if self.attn_backend.get_name() == "flashinfer":
                            paged_kv_indices.extend(block_table)
                            paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                   len(block_table))
                            # Page size is self.block_size; a full last page
                            # reports block_size rather than 0.
                            last_page_len = seq_data.get_len(
                            ) % self.block_size
                            if last_page_len == 0:
                                last_page_len = self.block_size
                            paged_kv_last_page_len.append(last_page_len)
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(range(context_len, seq_len))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
                    decode_seq_lens.append(sliding_seq_len)

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

                lora_index_mapping += [lora_id] * query_len
                lora_prompt_mapping.extend(
                    [lora_id] *
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     is not None else 1))

                mm_data = seq_group_metadata.multi_modal_data
                if mm_data is not None:
                    # Process multi-modal data
                    if self.multi_modal_input_processor is None:
                        raise ValueError(
                            "Multi-modal inputs are only supported by "
                            "vision language models.")

                    mm_kwargs = self.multi_modal_input_processor(mm_data)
                    for k, v in mm_kwargs.items():
                        multi_modal_kwargs_list[k].append(v)

                if _is_block_tables_empty(seq_group_metadata.block_tables):
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue

                # Compute the slot mapping.
                block_table = seq_group_metadata.block_tables[seq_id]

                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, query_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                start_idx = 0
                if self.sliding_window is not None:
                    if is_prompt:
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
                            "Prefix caching is currently not supported with "
                            "sliding window attention in V1 block manager")
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)

        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
                seq_lens.append(1)
                block_tables.append([])
                lora_index_mapping.append(0)
            batch_size = graph_batch_size
            num_decode_tokens = batch_size

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=self.device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
            block_tables = make_tensor_with_pad(
                block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=self.device,
            )

        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)

        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                  dtype=torch.int,
                                                  device=self.device)
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                use_cuda_graph=False,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables,
                workspace_buffer=self.flashinfer_workspace_buffer,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
                # The page size must match the block size used to build
                # paged_kv_last_page_len above (was hard-coded to 16).
                page_size=self.block_size,
                seq_start_loc=seq_start_loc,
                data_type=kv_cache_dtype)
        else:
            context_lens_tensor = torch.tensor(context_lens,
                                               dtype=torch.int,
                                               device=self.device)
            query_lens_tensor = torch.tensor(query_lens,
                                             dtype=torch.long,
                                             device=self.device)
            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                          dtype=torch.int32,
                                          device=self.device)

            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=query_start_loc.dtype,
                         out=query_start_loc[1:])

            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
                seq_lens_tensor=seq_lens_tensor,
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

        return ModelInput(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
        )

    def prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[AttentionMetadata],
               SamplingMetadata, Set[LoRARequest], Optional[LoRAMapping],
               Dict[str, torch.Tensor]]:
        """Prepare the model inputs on the driver and share them with workers.

        On the driver worker the inputs are built from
        ``seq_group_metadata_list`` and broadcast with
        ``broadcast_tensor_dict(..., src=0)``. Every other worker receives
        that broadcast and rebuilds the attention and sampling metadata from
        it.

        Args:
            seq_group_metadata_list: Sequence groups for this step. Required
                (non-None) on the driver worker; unused on other workers.

        Returns:
            A tuple ``(input_tokens, input_positions, attn_metadata,
            sampling_metadata, lora_requests, lora_mapping,
            multi_modal_kwargs)``. ``lora_mapping`` is ``None`` when LoRA is
            disabled, and ``attn_metadata`` is ``None`` on a non-driver
            worker when no attention fields were broadcast.
        """
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors. ModelInput is read by field name rather
            # than unpacked positionally, so reordering or extending its
            # fields cannot silently mix up these values.
            model_input = self._prepare_model_input(seq_group_metadata_list)
            input_tokens = model_input.input_tokens
            input_positions = model_input.input_positions
            attn_metadata = model_input.attn_metadata
            lora_mapping = model_input.lora_mapping
            lora_requests = model_input.lora_requests
            multi_modal_kwargs = model_input.multi_modal_kwargs

            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, model_input.seq_lens,
                model_input.query_lens, self.device, self.pin_memory)

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_kwargs": multi_modal_kwargs,
                "num_prefill_tokens": model_input.num_prefill_tokens,
                "num_decode_tokens": model_input.num_decode_tokens,
                "slot_mapping": model_input.slot_mapping,
                "num_prefills": model_input.num_prefills,
            }
            # Flatten the attention metadata into the same dict so that a
            # single broadcast carries everything the workers need.
            if attn_metadata:
                metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            # Every key left over is forwarded to the backend's metadata
            # constructor.
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                attn_metadata = None
            # Sampling happens only on the driver (see execute_model). The
            # other workers still compute logits, so they get a minimal
            # SamplingMetadata that carries only the selected indices.
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                num_prompts=0,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, lora_requests, lora_mapping,
                multi_modal_kwargs)
730

731
732
733
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run one forward pass and, on the driver worker, sample tokens.

        Args:
            seq_group_metadata_list: Sequence groups for this step. Required
                on the driver worker; unused on other workers.
            kv_caches: Per-layer KV cache tensors handed to the model.

        Returns:
            ``None`` on non-driver workers. On the driver it returns the
            ``SamplerOutput``, with ``hidden_states`` attached when
            ``self.return_hidden_states`` is set.
        """
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         lora_requests, lora_mapping, multi_modal_kwargs
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        use_graph = prefill_meta is None and decode_meta.use_cuda_graph
        if use_graph:
            # The inputs were padded to a captured batch size, so the padded
            # length selects the graph runner.
            graph_batch_size = input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        hidden_states = model_executable(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            **multi_modal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert seq_group_metadata_list is not None
            indices = sampling_metadata.selected_token_indices
            if seq_group_metadata_list[0].is_prompt:
                hidden_states = hidden_states.index_select(0, indices)
            elif use_graph:
                # A CUDA-graph replay runs on a batch padded up to the
                # captured size. Drop the padding rows so that only real
                # tokens are returned.
                hidden_states = hidden_states[:len(indices)]
            output.hidden_states = hidden_states

        return output

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a worst-case dummy batch to profile memory usage.

        Builds up to ``max_num_seqs`` dummy prompt sequences that together
        hold ``max_num_batched_tokens`` tokens. When a vision-language config
        is present, fewer sequences are used. The sequences are optionally
        tagged with dummy LoRA adapters, then executed with no KV caches.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Use the maximum number of distinct LoRA adapters (max_loras): that
        # is the worst case for memory consumption. One dummy adapter is
        # registered per slot through the warmup path, and the adapters are
        # assigned to the dummy sequences round-robin.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Vision encoding may need additional GPU memory, and that memory
        # must be accounted for when the GPU blocks for the vLLM block
        # manager are calculated. To exercise the worst case for GPU memory
        # consumption, the number of sequences (batch size) is chosen to
        # maximize the number of images processed.
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens // vlm_config.image_feature_size)
            # If image_feature_size exceeds the token budget, the division
            # yields 0. That would leave an empty dummy batch, so keep at
            # least one sequence.
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // "
                        f"{vlm_config.image_feature_size})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Spread the token budget as evenly as possible: the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            if vlm_config is None:
                seq_data = SequenceData([0] * seq_len)
                dummy_multi_modal_data = None
            else:
                seq_data, dummy_multi_modal_data = MULTIMODAL_REGISTRY \
                    .dummy_data_for_profiling(seq_len, model_config, vlm_config)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        self.execute_model(seqs, kv_caches)
        torch.cuda.synchronize()

860
    def remove_all_loras(self):
        """Ask the LoRA manager to remove every loaded LoRA adapter.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_loras()
864

865
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward ``lora_requests`` and ``lora_mapping`` to the LoRA manager.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Hand ``lora_request`` to the LoRA manager's ``add_lora``.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        manager = self.lora_manager
        if manager:
            return manager.add_lora(lora_request)
        raise RuntimeError("LoRA is not enabled.")

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to remove the adapter with id ``lora_id``.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")
880
881
882
883
884

    def pin_lora(self, lora_id: int) -> bool:
        """Delegate to the LoRA manager's ``pin_lora`` for ``lora_id``.

        Returns:
            Whatever the manager's ``pin_lora`` returns.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        manager = self.lora_manager
        if manager:
            return manager.pin_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")
885
886
887
888
889
890

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager.

        Raises:
            RuntimeError: if no LoRA manager is configured.
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")

891
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured per batch size in
        ``_BATCH_SIZES_TO_CAPTURE`` up to the padded ``max_num_seqs``. The
        runners are stored in ``self.graph_runners`` and all share
        ``self.graph_memory_pool``.

        Args:
            kv_caches: Per-layer KV cache tensors used during capture.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        # Every slot is set to the padding id. Presumably the attention
        # kernels skip such slots, so the dummy runs do not write to real
        # cache entries. NOTE(review): confirm this in the backend.
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Prepare buffer for outputs. These will be reused for all batch sizes.
        # It will be filled after the first graph capture.
        hidden_states: Optional[torch.Tensor] = None

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(batch_size_capture_list):
                # Create dummy attn_metadata.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=slot_mapping[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=block_tables[:batch_size],
                    use_cuda_graph=True,
                )

                if self.lora_config:
                    lora_mapping = LoRAMapping(
                        [0] * batch_size,
                        [0] * batch_size,
                    )
                    self.set_active_loras(set(), lora_mapping)

                graph_runner = CUDAGraphRunner(self.model)
                # After the first (largest) capture, each smaller batch
                # writes its output into a prefix slice of the same buffer.
                hidden_states = graph_runner.capture(
                    input_tokens[:batch_size],
                    input_positions[:batch_size],
                    hidden_states[:batch_size]
                    if hidden_states is not None else None,
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                    stream=graph_capture_context.stream,
                )
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
983

984
985
986
987
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``self.model_config``."""
        config = self.model_config
        return config.get_vocab_size()

988
989
990
991
992
993
994
995

class CUDAGraphRunner:
    """Capture one forward pass of ``model`` as a CUDA graph and replay it.

    The tensors passed to :meth:`capture` become the graph's static input
    buffers. :meth:`forward` copies fresh inputs into those buffers, replays
    the graph and returns the static output buffer.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Static tensors recorded by capture(); filled in there.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        """The captured graph. Asserts that capture() has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: Optional[torch.Tensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> torch.Tensor:
        """Warm up ``model``, then capture one forward pass into a CUDA graph.

        Args:
            input_ids: Token-id buffer; later replays read from this tensor.
            positions: Position buffer; later replays read from this tensor.
            hidden_states: Optional existing output buffer. When given, the
                captured graph copies the model output into it; otherwise
                the model's own output tensor becomes the buffer.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata whose ``slot_mapping`` and
                decode ``seq_lens_tensor``/``block_tables`` tensors are
                recorded as input buffers.
            memory_pool: Graph memory pool handle to share, or ``None``.
            stream: Stream on which to capture.
            **kwargs: Extra model arguments, passed through unchanged.

        Returns:
            The static output buffer holding the hidden states.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
            if hidden_states is not None:
                hidden_states.copy_(output_hidden_states)
            else:
                hidden_states = output_hidden_states
            del output_hidden_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy new inputs into the captured buffers and replay the graph.

        The returned tensor is the same output buffer object on every call.
        Its contents are overwritten by the next replay.

        NOTE(review): ``**kwargs`` are accepted but not copied into the
        graph, so any extra inputs given at capture time are replayed with
        their captured values. Confirm that decode never needs them.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1090

1091
def _get_graph_batch_size(batch_size: int) -> int:
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...

    Args:
        batch_size: Number of sequences in the actual batch.

    Returns:
        ``batch_size`` itself when it is at most 2, 4 when it is 3 or 4,
        and otherwise ``batch_size`` rounded up to the next multiple of
        ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Ceiling-divide by the alignment, then scale back up.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1104
1105


1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False