"vscode:/vscode.git/clone" did not exist on "35fda7b4af556e7eeef2b5dcb3638435382b2576"
model_runner.py 49.9 KB
Newer Older
1
import dataclasses
2
import gc
3
import time
4
import warnings
5
from collections import defaultdict
6
7
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
                    TypeVar, Union)
8

9
import numpy as np
10
import torch
11
import torch.nn as nn
12

13
from vllm.attention import AttentionMetadata, get_attn_backend
14
15
16
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
17
from vllm.distributed.parallel_state import graph_capture
18
from vllm.logger import init_logger
19
20
21
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
22
from vllm.model_executor import SamplingMetadata
23
from vllm.model_executor.model_loader import get_model
24
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
25
from vllm.multimodal import MULTIMODAL_REGISTRY
26
from vllm.sampling_params import SamplingParams
27
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
28
29
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
30
31
32
33
34
35
36
37
38
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
39
40
41
42

logger = init_logger(__name__)

# Slot index written into the slot mapping for tokens whose KV entries must
# not be stored: padding added for CUDA graphs, memory-profiling runs with no
# block tables, and prompt tokens masked out by the sliding window.
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch (= decode token) sizes 1, 2, 4, 8, 16, 24, 32, 40,
# ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
_NUM_WARMUP_ITERS = 2

# Lets the runner base class and `from_broadcasted_tensor_dict` be generic
# over whichever ModelInputForGPU subclass a concrete runner produces.
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


@dataclasses.dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    This base class contains metadata needed for the base model forward pass
    but not metadata for possible additional steps, e.g., sampling. Model
    runners that run additional steps should subclass this class to add
    additional fields.

    Note that ``seq_lens`` and ``query_lens`` are not placed in the dict
    built by ``as_broadcastable_tensor_dict``; only the fields listed there
    (plus whatever the attention-metadata helper adds) are broadcast.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields to send to other workers.

        The attention metadata is merged into the dict by
        ``_add_attn_metadata_broadcastable_dict``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild an instance from a dict received via broadcast.

        Args:
            tensor_dict: Keyword arguments for the dataclass constructor, as
                produced by ``as_broadcastable_tensor_dict`` on the sender.
            attn_backend: When given, attention-metadata entries in
                ``tensor_dict`` are converted back into attention metadata
                by ``_init_attn_metadata_from_tensor_dict`` first.

        Returns:
            A new instance of ``cls`` built from the (possibly converted)
            dict.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner.

    Extends ModelInputForGPU with sampling metadata, which is broadcast
    together with the base-class fields.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the base-class broadcast dict plus sampling metadata."""
        # Reuse the base class's field list (including the attention
        # metadata it merges in) so the two classes cannot drift apart.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Rebuild an instance from a dict received via broadcast.

        Sampling-metadata entries are restored first, then (when
        ``attn_backend`` is given) attention-metadata entries, and the
        result is passed to the dataclass constructor.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
136
137
138
139
140
141

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Store the engine configs and set up state shared by GPU runners.

        The model itself is not created here; ``load_model`` must be called
        afterwards to populate ``self.model`` (and ``self.lora_manager`` when
        LoRA is configured).

        Args:
            kv_cache_dtype: KV-cache dtype name, forwarded to
                ``get_attn_backend``.
            is_driver_worker: Stored as ``self.is_driver_worker``.
            vision_language_config: When set, a multi-modal input processor
                is created from ``MULTIMODAL_REGISTRY``.
            return_hidden_states: Stored as ``self.return_hidden_states``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        # Presumably one runner per captured batch size; filled in outside
        # this constructor.
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        # NOTE: get_max_block_per_batch reads block_size and
        # max_seq_len_to_capture, so both must be assigned above this line.
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

        # No attention backend is selected when the model reports zero
        # attention heads.
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.sliding_window,
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None

        # Create processor for multi-modal data
        if self.vision_language_config is not None:
            self.multi_modal_input_processor = MULTIMODAL_REGISTRY \
                .create_input_processor(
                    self.model_config,
                    self.vision_language_config,
                )
        else:
            self.multi_modal_input_processor = None

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set if the backend is flashinfer. Deliberately left unassigned:
        # it is allocated on first use, detected with hasattr().
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        """Load the model onto the device and apply optional wrappers.

        The memory consumed while loading is recorded in
        ``self.model_memory_usage`` and logged. If a LoRA config is present,
        the model is validated for LoRA support and wrapped via an
        ``LRUCacheWorkerLoRAManager``. On ROCm with an FP8 KV cache, KV-cache
        scaling factors are loaded from ``quantization_param_path`` when one
        is configured; otherwise a warning is logged.

        Raises:
            RuntimeError: If FP8 KV-cache scaling factors were provided but
                the model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert hasattr(self.model, "supported_lora_modules"
                           ) and self.model.supported_lora_modules, (
                               "Model does not support LoRA")
            assert hasattr(
                self.model,
                "embedding_modules"), "Model does not have embedding_modules"
            assert hasattr(self.model, "embedding_padding_modules"
                           ), "Model does not have embedding_padding_modules"
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Format eagerly: unlike logger calls, exception
                    # constructors do not interpolate %-style arguments, so
                    # passing the class as a second argument produced a
                    # tuple message with a literal "%s".
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save ``self.model`` through ``ShardedStateLoader.save_model``.

        The loader is imported on demand. ``path`` is passed positionally;
        ``pattern`` and ``max_size`` are forwarded unchanged as keywords.
        """
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        forwarded_options = dict(pattern=pattern, max_size=max_size)
        ShardedStateLoader.save_model(self.model, path, **forwarded_options)

290
291
292
293
294
295
296
297
298
299
    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Save ``self.model`` through ``TensorizerLoader.save_model``.

        The loader is imported on demand and receives the given
        ``tensorizer_config`` as a keyword argument.
        """
        from vllm.model_executor.model_loader.loader import TensorizerLoader
        model_to_save = self.model
        TensorizerLoader.save_model(model_to_save,
                                    tensorizer_config=tensorizer_config)

300
301
    def get_max_block_per_batch(self) -> int:
        """Return the number of KV-cache blocks needed to hold
        ``max_seq_len_to_capture`` tokens, rounded up to a whole block.

        This is the column count used for ``self.graph_block_tables``.
        """
        block_size = self.block_size
        # Integer ceiling division.
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_model_input_tensors(
305
306
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
307
308
309
310
    ) -> TModelInputForGPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.
311
312
313
314
315
316
317
318
319
320
321

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
322
323
324
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
325
326
327
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()
328

329
        seq_lens: List[int] = []
330
331
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
332
        context_lens: List[int] = []
333
        query_lens: List[int] = []
334
        block_tables: List[List[int]] = []
335
336
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
337
338
339
340
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0
341

342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request’s page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

360
        if len(seq_group_metadata_list) == 0:
361
            return self._model_input_cls()
362

363
364
365
366
367
368
        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

369
370
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
371
            is_prompt = seq_group_metadata.is_prompt
372

373
            for seq_id in seq_ids:
374
375
376
377
378
379
380
381
382
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

383
                seq_data = seq_group_metadata.seq_data[seq_id]
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]
401

402
403
404
405
406
407
408
                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle slinding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

433
434
435
436
437
438
439
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]
440
441
442
443
444
445
446

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

447
448
449
450
451
452
453
454
455
456
457
458
459
                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
460
461
462
                        if curr_sliding_window_blocks is not None:
                            block_table = block_table[
                                -curr_sliding_window_blocks:]
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
                        if self.attn_backend.get_name() == "flashinfer":
                            paged_kv_indices.extend(block_table)
                            paged_kv_indptr.append(paged_kv_indptr[-1] +
                                                   len(block_table))
                            last_page_len = seq_data.get_len(
                            ) % self.block_size
                            if last_page_len == 0:
                                last_page_len = self.block_size
                            paged_kv_last_page_len.append(last_page_len)
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

480
481
482
                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(list(range(context_len, seq_len)))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
499
                    decode_seq_lens.append(sliding_seq_len)
500
501
502
503

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

504
                lora_index_mapping += [lora_id] * query_len
505
506
                lora_prompt_mapping.extend(
                    [lora_id] *
507
                    (query_len if seq_group_metadata.sampling_params
508
                     and seq_group_metadata.sampling_params.prompt_logprobs
509
                     is not None else 1))
510

511
512
513
514
515
516
517
518
519
520
521
                mm_data = seq_group_metadata.multi_modal_data
                if mm_data is not None:
                    # Process multi-modal data
                    if self.multi_modal_input_processor is None:
                        raise ValueError(
                            "Multi-modal inputs are only supported by "
                            "vision language models.")

                    mm_kwargs = self.multi_modal_input_processor(mm_data)
                    for k, v in mm_kwargs.items():
                        multi_modal_kwargs_list[k].append(v)
522
523
524
525
526
527
528
529

                if _is_block_tables_empty(seq_group_metadata.block_tables):
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue
530

531
                # Compute the slot mapping.
532
533
                block_table = seq_group_metadata.block_tables[seq_id]

534
535
536
537
538
539
540
                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                start_idx = 0
541
                if self.sliding_window is not None:
542
                    if is_prompt:
543
544
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
545
                            "Prefix caching is currently not supported with "
546
                            "sliding window attention in V1 block manager")
547
548
549
550
551
552
553
554
555
556
557
558
559
560
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)
561

562
563
564
565
        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)
566

567
        # If cuda graph can be used, pad tensors accordingly.
568
        # See `capture_model` API for more details.
569
570
571
572
573
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
574
575
576
577
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
578
579
580
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
581
                seq_lens.append(1)
582
                block_tables.append([])
583
                lora_index_mapping.append(0)
584
            batch_size = graph_batch_size
585
            num_decode_tokens = batch_size
586
587
588
589
590
591
592
593

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
594
            block_tables = torch.tensor(input_block_tables, device=self.device)
595
        else:
596
597
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
598
            block_tables = make_tensor_with_pad(
599
                block_tables,
600
                max_len=max_block_table_len,
601
602
                pad=0,
                dtype=torch.int,
603
                device=self.device,
604
            )
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)
628

629
        if self.attn_backend.get_name() == "flashinfer":
630
631
632
633
634
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
                self.flashinfer_workspace_buffer = torch.empty(
                    16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
635
            paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
636
637
                                                  dtype=torch.int,
                                                  device=self.device)
638
639
640
641
642
            paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                   dtype=torch.int,
                                                   device=self.device)
            paged_kv_last_page_len_tensor = torch.tensor(
                paged_kv_last_page_len, dtype=torch.int, device=self.device)
643
644
645
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
            attn_metadata = self.attn_backend.make_metadata(
646
647
648
649
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
650
                use_cuda_graph=False,
651
652
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables,
653
                workspace_buffer=self.flashinfer_workspace_buffer,
654
655
656
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
657
658
659
660
661
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
662
663
                page_size=16,
                seq_start_loc=seq_start_loc,
664
665
                data_type=kv_cache_dtype)
        else:
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
            context_lens_tensor = torch.tensor(context_lens,
                                               dtype=torch.int,
                                               device=self.device)
            query_lens_tensor = torch.tensor(query_lens,
                                             dtype=torch.long,
                                             device=self.device)
            query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                          dtype=torch.int32,
                                          device=self.device)

            torch.cumsum(query_lens_tensor,
                         dim=0,
                         dtype=query_start_loc.dtype,
                         out=query_start_loc[1:])

681
            attn_metadata = self.attn_backend.make_metadata(
682
683
684
685
686
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
687
                seq_lens_tensor=seq_lens_tensor,
688
689
690
691
692
693
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
694
695
696
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )
697
698
699
700
701
702
703
704
705

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

706
707
708
709
710
        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

711
        return self._model_input_cls(
712
713
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
714
            attn_metadata=attn_metadata,
715
716
717
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
718
            lora_requests=lora_requests,
719
            multi_modal_kwargs=multi_modal_kwargs,
720
        )
721

722
723
724
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run the model once on worst-case dummy inputs to profile memory.

        Builds up to ``scheduler_config.max_num_seqs`` dummy prompt sequences
        whose lengths sum to ``scheduler_config.max_num_batched_tokens``.
        Warmup LoRAs are attached when LoRA is enabled, and dummy multi-modal
        data when a vision-language config is set. The batch is executed with
        ``None`` in place of every layer's KV cache, then the GPU is
        synchronized.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register one dummy LoRA per slot (``max_loras``). This is the
        # largest number of distinct adapters requests can use at once, and
        # therefore the worst-case LoRA memory consumption. Sequences then
        # cycle through these dummy requests.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the vLLM
        # block manager. To exercise the worst-case GPU memory consumption,
        # the number of seqs (batch size) is chosen to maximize the number of
        # images processed.
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(
                max_num_seqs,
                max_num_batched_tokens // vlm_config.image_feature_size)
            if max_num_seqs < 1:
                # The token budget is smaller than one image's features.
                # Without this clamp the loop below would build no sequences
                # and the profiling pass would run on an empty batch.
                logger.warning(
                    "Computed max_num_seqs (min(%s, %s // %s)) to be less "
                    "than 1. Setting it to the minimum value of 1.",
                    max_num_seqs_orig, max_num_batched_tokens,
                    vlm_config.image_feature_size)
                max_num_seqs = 1

        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # ``max_num_batched_tokens % max_num_seqs`` groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))

            if vlm_config is None:
                seq_data = SequenceData([0] * seq_len)
                dummy_multi_modal_data = None
            else:
                seq_data, dummy_multi_modal_data = (
                    MULTIMODAL_REGISTRY.dummy_data_for_profiling(
                        seq_len, model_config, vlm_config))

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs; no KV cache tensors exist yet,
        # so every layer gets a ``None`` placeholder.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        model_input = self.prepare_model_input(seqs)
        self.execute_model(model_input, kv_caches)
        torch.cuda.synchronize()
    def remove_all_loras(self):
        """Ask the LoRA manager to drop every adapter it currently holds.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            manager.remove_all_loras()
            return
        raise RuntimeError("LoRA is not enabled.")
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward ``lora_requests`` and ``lora_mapping`` to the LoRA manager.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA manager.

        Returns:
            Whatever the manager's ``add_lora`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Remove adapter ``lora_id`` through the LoRA manager.

        Returns:
            Whatever the manager's ``remove_lora`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def pin_lora(self, lora_id: int) -> bool:
        """Pin adapter ``lora_id`` through the LoRA manager.

        Returns:
            Whatever the manager's ``pin_lora`` returns.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.pin_lora(lora_id)
        raise RuntimeError("LoRA is not enabled.")

    def list_loras(self) -> Set[int]:
        """Return the adapter ids reported by the LoRA manager.

        Raises:
            RuntimeError: if this runner has no LoRA manager.
        """
        manager = self.lora_manager
        if manager:
            return manager.list_loras()
        raise RuntimeError("LoRA is not enabled.")
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
        """Record the decode step of the model as CUDA graphs.

        One graph is captured per entry of ``_BATCH_SIZES_TO_CAPTURE`` that
        does not exceed the padded ``scheduler_config.max_num_seqs``, from
        the largest size down. All graphs share one memory pool and slice a
        single set of dummy input buffers.

        CUDA graphs gain little once a batch carries more than ~200 tokens.
        Their fixed-size tensors also make large or variable batch sizes
        expensive in GPU memory. Only decoding requests are therefore
        captured; prefill and mixed (chunked prefill + decode) batches run
        eagerly. Because capture is decode-only, each sequence in a captured
        batch is assumed to contribute exactly one token.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        capture_start = time.perf_counter()

        # Dummy inputs sized for the largest capturable batch; every graph
        # below reads a prefix slice of these same tensors.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        dummy_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        dummy_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        dummy_slots = torch.empty(max_batch_size,
                                  dtype=torch.long).cuda().fill_(_PAD_SLOT_ID)
        dummy_seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        dummy_block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Output buffer shared by all graphs. It stays None until the first
        # capture returns one; later captures write into a prefix of it.
        shared_output: Optional[torch.Tensor] = None

        padded_max_seqs = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        capture_sizes = [
            size for size in _BATCH_SIZES_TO_CAPTURE
            if size <= padded_max_seqs
        ]

        with graph_capture() as capture_ctx:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for batch_size in reversed(capture_sizes):
                # Decode-only dummy attention metadata for this batch size.
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=0,
                    num_prefill_tokens=0,
                    num_decode_tokens=batch_size,
                    slot_mapping=dummy_slots[:batch_size],
                    seq_lens=None,
                    seq_lens_tensor=dummy_seq_lens[:batch_size],
                    max_query_len=None,
                    max_prefill_seq_len=0,
                    max_decode_seq_len=self.max_seq_len_to_capture,
                    query_start_loc=None,
                    seq_start_loc=None,
                    context_lens_tensor=None,
                    block_tables=dummy_block_tables[:batch_size],
                    use_cuda_graph=True,
                )

                if self.lora_config:
                    self.set_active_loras(
                        set(),
                        LoRAMapping([0] * batch_size, [0] * batch_size))

                graph_runner = CUDAGraphRunner(self.model)
                if shared_output is None:
                    output_slice = None
                else:
                    output_slice = shared_output[:batch_size]
                shared_output = graph_runner.capture(
                    dummy_tokens[:batch_size],
                    dummy_positions[:batch_size],
                    output_slice,
                    kv_caches,
                    attn_metadata,
                    memory_pool=self.graph_memory_pool,
                    stream=capture_ctx.stream,
                )
                # Reuse this graph's pool for the next (smaller) capture.
                self.graph_memory_pool = graph_runner.graph.pool()
                self.graph_runners[batch_size] = graph_runner

        elapsed = time.perf_counter() - capture_start
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed)
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        return config.get_vocab_size()

class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """GPU model runner that also performs the sampling step.

    Extends the base runner's input preparation with ``SamplingMetadata``.
    It runs the forward pass through either the eager model or a captured
    CUDA graph, computes logits and, on the driver worker only, samples the
    next tokens.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to ``ModelInputForGPUWithSamplingMetadata.
        from_broadcasted_tensor_dict``, passing this runner's attention
        backend along.
        """
        return (
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        ``is_prompt`` is taken from the first sequence group, or ``None``
        when the list is empty.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        """Run the forward pass, compute logits and sample the next tokens.

        A decode-only batch whose metadata requests CUDA graphs replays the
        graph captured for its (padded) batch size. Every other batch runs
        the eager model.

        Returns:
            The sampler output on the driver worker. On every other worker
            ``None`` is returned after computing logits.
        """
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        hidden_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            **multi_modal_kwargs,
        )

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return None

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            if model_input.is_prompt:
                assert model_input.sampling_metadata is not None
                hidden_states = hidden_states.index_select(
                    0, model_input.sampling_metadata.selected_token_indices)
            output.hidden_states = hidden_states

        return output


class CUDAGraphRunner:
    """Captures one forward pass of ``model`` as a CUDA graph and replays it.

    ``capture`` records the graph on fixed tensors, which are kept in
    ``input_buffers`` and ``output_buffers``. ``forward`` copies new inputs
    into those buffers, replays the graph and returns the output buffer.
    """

    def __init__(self, model: nn.Module):
        self.model = model
        # Values are tensors except "kv_caches", which holds the list passed
        # to ``capture``; hence ``Any`` rather than ``torch.Tensor``.
        self.input_buffers: Dict[str, Any] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: Optional[torch.Tensor],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> torch.Tensor:
        """Warm up the model, then record one forward pass as a CUDA graph.

        Args:
            input_ids: Static token-id tensor the graph reads.
            positions: Static position tensor the graph reads.
            hidden_states: Optional pre-allocated buffer that the captured
                graph copies the model output into. When ``None``, the
                model's own output tensor becomes the output buffer.
            kv_caches: KV cache tensors passed to the model.
            attn_metadata: Attention metadata. Its ``slot_mapping`` and its
                decode metadata's ``seq_lens_tensor`` and ``block_tables``
                are kept as input buffers.
            memory_pool: Graph memory pool to capture into, or ``None``.
            stream: Stream on which the graph is captured.
            **kwargs: Extra model arguments used during warmup and capture.

        Returns:
            The output buffer the graph writes its hidden states into.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                **kwargs,
            )
            if hidden_states is not None:
                hidden_states.copy_(output_hidden_states)
            else:
                hidden_states = output_hidden_states
            del output_hidden_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }
        self.output_buffers = {"hidden_states": hidden_states}
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Copy fresh inputs into the captured buffers and replay the graph.

        ``kv_caches`` and any extra ``kwargs`` are accepted for call
        compatibility but not used.

        Returns:
            The captured output buffer, which every replay overwrites.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        self.input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # Run the graph.
        self.graph.replay()

        # Return the output tensor.
        return self.output_buffers["hidden_states"]

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

def _get_graph_batch_size(batch_size: int) -> int:
    """Return the CUDA-graph batch size used for ``batch_size`` sequences.

    Sizes 1 and 2 are used as-is, 3 and 4 become 4, and anything larger is
    rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    """
    if batch_size > 4:
        full_chunks, leftover = divmod(batch_size, _BATCH_SIZE_ALIGNMENT)
        # Add one more chunk when the size is not already aligned.
        return (full_chunks + (leftover > 0)) * _BATCH_SIZE_ALIGNMENT
    return batch_size if batch_size <= 2 else 4


1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False