# model_runner.py 61.3 KB
# Newer Older
1
import dataclasses
2
import gc
3
import time
4
import warnings
5
from collections import defaultdict
6
7
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
                    TypeVar, Union)
8

9
import numpy as np
10
import torch
11
import torch.distributed
12
import torch.nn as nn
13

14
15
16
17
18
19
20
21
22
23
24
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

25
from vllm.attention import AttentionMetadata, get_attn_backend
26
27
28
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
29
from vllm.distributed import get_pp_group
30
from vllm.distributed.parallel_state import graph_capture
31
from vllm.inputs import INPUT_REGISTRY
32
from vllm.logger import init_logger
33
34
35
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
36
from vllm.model_executor import SamplingMetadata
37
from vllm.model_executor.model_loader import get_model
38
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
39
from vllm.model_executor.models.interfaces import supports_lora
40
from vllm.multimodal import MULTIMODAL_REGISTRY
41
from vllm.sampling_params import SamplingParams
42
43
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
44
45
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available, make_tensor_with_pad)
46
47
48
49
50
51
52
53
54
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
55
56
57
58

# Module-level logger; load_model() uses it to report weight-loading memory
# and KV-cache scaling-factor status.
logger = init_logger(__name__)

# Slot value written into the slot mapping for tokens that must not be stored
# in the KV cache: memory-profiling runs, tokens masked out by the sliding
# window, and CUDA-graph padding entries.
_PAD_SLOT_ID = -1
# LoRA rank constant; its use is not visible in this part of the file
# (presumably the rank of dummy LoRA adapters used for warm-up).
LORA_WARMUP_RANK = 8
# Spacing between captured batch sizes above 4.
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch (token) sizes 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
# Number of warm-up iterations; its use is not visible in this part of the
# file (presumably run before CUDA-graph capture).
_NUM_WARMUP_ITERS = 2

# Type variable used so GPUModelRunnerBase subclasses can be generic over
# their concrete ModelInputForGPU subclass.
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


@dataclasses.dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    This base class contains metadata needed for the base model forward pass
    but not metadata for possible additional steps, e.g., sampling. Model
    runners that run additional steps should subclass this class to add
    additional fields.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields that are broadcast to other workers as a dict.

        ``seq_lens`` and ``query_lens`` are not included. The attention
        metadata is handed to ``_add_attn_metadata_broadcastable_dict``,
        which presumably adds its entries to the dict in place.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild an instance from a dict received via broadcast.

        Args:
            tensor_dict: Field values keyed by field name.
            attn_backend: When given, ``tensor_dict`` is first passed through
                ``_init_attn_metadata_from_tensor_dict`` so attention-metadata
                entries are restored; when ``None`` the dict is used as-is.

        Returns:
            An instance of ``cls`` built from the (possibly converted) dict.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner. Extends ModelInputForGPU with the sampling
    metadata (and a driver-only ``is_prompt`` flag).
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the broadcastable fields, including sampling metadata.

        The base fields and attention metadata come from the parent class so
        both classes stay in sync; the sampling metadata is then handed to
        ``_add_sampling_metadata_broadcastable_dict``. ``is_prompt`` is
        deliberately not broadcast.
        """
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Rebuild an instance from a dict received via broadcast.

        The dict is first passed through
        ``_init_sampling_metadata_from_tensor_dict``; the parent class then
        restores attention metadata (when ``attn_backend`` is given) and
        constructs the instance.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        return super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
155
156
157
158
159
160

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        vision_language_config: Optional[VisionLanguageConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Store the engine configuration and set up runner state.

        The model itself is not loaded here: ``self.model`` and
        ``self.lora_manager`` are filled in by ``load_model``.

        Args:
            model_config: Model description; supplies the sliding window,
                ``max_seq_len_to_capture``, dtype and attention geometry.
            parallel_config: Parallelism settings; ``pipeline_parallel_size``
                sizes ``self.graph_runners``.
            scheduler_config: Scheduler settings.
            device_config: Target device settings; ``self.device`` is read
                from it.
            cache_config: KV-cache settings; supplies ``block_size``.
            load_config: Weight-loading settings.
            lora_config: LoRA settings, or ``None`` when LoRA is disabled.
            kv_cache_dtype: KV-cache dtype name, forwarded to attention
                backend selection.
            is_driver_worker: Whether this runner belongs to the driver
                worker; only stored here.
            vision_language_config: Optional vision-language settings.
            return_hidden_states: Flag that is only stored here.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture

        # One dict of CUDA graph runners per pipeline-parallel stage (keys
        # presumably are captured batch sizes).
        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        # No attention backend is created when the model reports zero
        # attention heads.
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

        # FlashInfer workspace buffers and wrappers; None until set elsewhere.
        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def load_model(self) -> None:
        """Load the model and apply optional post-load setup.

        The memory consumed while loading is recorded in
        ``self.model_memory_usage`` and logged. When LoRA is configured, the
        model is wrapped by an ``LRUCacheWorkerLoRAManager``. With an FP8 KV
        cache on ROCm, KV-cache scaling factors are loaded from
        ``quantization_param_path`` if one is configured.

        Raises:
            AssertionError: LoRA is configured but the model does not
                support it.
            RuntimeError: FP8 KV-cache scaling factors are provided but the
                model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(
                model_config=self.model_config,
                device_config=self.device_config,
                load_config=self.load_config,
                lora_config=self.lora_config,
                vision_language_config=self.vision_language_config,
                parallel_config=self.parallel_config,
                scheduler_config=self.scheduler_config,
                cache_config=self.cache_config,
            )

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(self.model), "Model does not support LoRA"

            # NOTE(review): self.vocab_size is not set in __init__; assumes a
            # vocab_size attribute/property defined elsewhere in the class.
            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style formatting to extra
                    # args, so the message must be formatted explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the loaded model via ``ShardedStateLoader.save_model``.

        Args:
            path: Destination passed to the loader.
            pattern: Optional pattern forwarded unchanged to the loader.
            max_size: Optional size limit forwarded unchanged to the loader.
        """
        # Imported lazily; presumably to avoid loading the loader module (or
        # an import cycle) at module import time.
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Save the loaded model via ``TensorizerLoader.save_model``.

        Args:
            tensorizer_config: Tensorizer settings forwarded unchanged to the
                loader.
        """
        # Imported lazily, like save_sharded_state's loader import.
        from vllm.model_executor.model_loader.loader import TensorizerLoader
        TensorizerLoader.save_model(
            self.model,
            tensorizer_config=tensorizer_config,
        )

    def get_max_block_per_batch(self) -> int:
        """Return how many KV-cache blocks cover ``max_seq_len_to_capture``.

        This is ``ceil(max_seq_len_to_capture / block_size)`` computed with
        integer arithmetic; ``__init__`` uses it as the column count of
        ``graph_block_tables``.
        """
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_model_input_tensors(
318
319
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
320
321
322
323
    ) -> TModelInputForGPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.
324
325
326
327
328
329
330
331
332
333
334

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
335
336
337
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
338
339
340
        lora_index_mapping: List[int] = []
        lora_prompt_mapping: List[int] = []
        lora_requests: Set[LoRARequest] = set()
341

342
        seq_lens: List[int] = []
343
344
        prefill_seq_lens: List[int] = []
        decode_seq_lens: List[int] = []
345
        context_lens: List[int] = []
346
        query_lens: List[int] = []
347
        block_tables: List[List[int]] = []
348
349
        multi_modal_kwargs_list: Dict[str,
                                      List[torch.Tensor]] = defaultdict(list)
350
351
352
353
        decode_only = True
        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = 0
354

355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
        # The following fields are only for flashinfer
        # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
        # for the precise definition of the following fields.
        # An example:
        # request 1, page indices [0, 5, 8]
        # request 2, page indices [1, 6, 7]
        # request 3, page indices [3, 4]
        # paged_kv_indices is a concatenation of page indices of all requests:
        # [0, 5, 8, 1, 6, 7, 3, 4]
        # paged_kv_indptr is used to index into paged_kv_indices:
        # [0, 3, 6, 8]
        paged_kv_indices: List[int] = []
        # 0 at the beginning of paged_kv_indptr indicates the start of the
        # first request’s page indices in the paged_kv_indices list.
        paged_kv_indptr: List[int] = [0]
        # paged_kv_last_page_len is the length of the last page of each request
        paged_kv_last_page_len: List[int] = []

373
        if len(seq_group_metadata_list) == 0:
374
            return self._model_input_cls()
375

376
377
378
379
380
381
        if self.sliding_window is not None:
            sliding_window_blocks = (self.sliding_window + self.block_size -
                                     1) // self.block_size
            block_aligned_sliding_window = \
                sliding_window_blocks * self.block_size

382
383
        for seq_group_metadata in seq_group_metadata_list:
            seq_ids = list(seq_group_metadata.seq_data.keys())
384
            is_prompt = seq_group_metadata.is_prompt
385

386
            for seq_id in seq_ids:
387
388
389
390
391
392
393
394
395
                computed_block_nums = seq_group_metadata.computed_block_nums
                if (self.scheduler_config is not None
                        and self.scheduler_config.chunked_prefill_enabled
                        and not (computed_block_nums is None
                                 or computed_block_nums == [])):
                    raise RuntimeError(
                        "chunked prefill cannot be used with prefix caching "
                        "now.")

396
                seq_data = seq_group_metadata.seq_data[seq_id]
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                else:
                    # get_num_computed_tokens is incorrect for spec decoding.
                    # So, we should have a special logic here.
                    # TODO(sang): Fix it.
                    context_len = seq_data.get_len() - 1

                seq_len = min(
                    seq_data.get_len(),
                    context_len + seq_group_metadata.token_chunk_size)
                if is_prompt:
                    tokens = seq_data.get_token_ids()[context_len:seq_len]
                else:
                    # Optimization. get_token_ids requires the entire copy of
                    # tokens.
                    tokens = [seq_data.get_last_token_id()]
414

415
416
417
418
419
420
421
                # Prefix cache was hit.
                # Prefix is not supported with sliding_window
                prefix_cache_hit = (computed_block_nums is not None
                                    and len(computed_block_nums) > 0
                                    and self.sliding_window is None
                                    and is_prompt)

422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
                # These are seq_len/context_len capped to the sliding window.
                # They are passed to decode kernel.
                # We still need original seq_len/context_len to compute slot
                # mapping (and input position) below.
                curr_sliding_window_blocks = None
                sliding_seq_len = seq_len
                sliding_context_len = context_len

                # TODO(sang): This is a hack to make sliding window work with
                # paged attn. We can remove it if we make paged attn kernel
                # to properly handle slinding window attn.
                if (self.sliding_window is not None and not is_prompt):
                    curr_sliding_window_blocks = sliding_window_blocks
                    if self.scheduler_config.use_v2_block_manager:
                        # number of elements in last block
                        suff_len = seq_len % self.block_size
                        sliding_seq_len = min(
                            seq_len, block_aligned_sliding_window + suff_len)
                        if suff_len > 0:
                            curr_sliding_window_blocks += 1
                    else:
                        sliding_seq_len = min(seq_len, self.sliding_window)
                    sliding_context_len = sliding_seq_len - 1

446
447
448
449
450
451
452
                # TODO(sang): Combine chunked prefill and prefix caching by
                # only allowing multiple of block_size chunk size.
                # NOTE: This only works for oooooooxxx style attention.
                if prefix_cache_hit:
                    assert computed_block_nums is not None
                    context_len = len(computed_block_nums) * self.block_size
                    tokens = tokens[context_len:]
453
454
455
456
457
458
459

                    # need to think what to set it to when we have both sliding
                    # window and prefix caching...
                    assert self.sliding_window is None, \
                        "Prefix caching is not supported with sliding window"
                    sliding_context_len = context_len

460
461
462
463
464
465
466
467
468
469
470
471
472
                    if self.attn_backend.get_name() == "flash-attn":
                        # NOTE(woosuk): For flash-attn, the block table should
                        # include the entries for the incoming prefill tokens.
                        # TODO(woosuk): This is a temporary fix. We should
                        # provide a unified interface for different backends.
                        block_table = seq_group_metadata.block_tables[seq_id]
                    else:
                        block_table = computed_block_nums
                elif (self.scheduler_config.chunked_prefill_enabled
                      or not is_prompt):
                    if seq_group_metadata.block_tables is not None:
                        # chunked prefill or decode
                        block_table = seq_group_metadata.block_tables[seq_id]
473
474
475
                        if curr_sliding_window_blocks is not None:
                            block_table = block_table[
                                -curr_sliding_window_blocks:]
476
477
478
479
480
481
482
483
                    else:
                        # Only happens when memory profiling runs.
                        block_table = []
                else:
                    # Prefill without chunked prefill or memory profiling.
                    block_table = []
                block_tables.append(block_table)

484
485
486
                seq_lens.append(sliding_seq_len)
                context_lens.append(sliding_context_len)
                query_len = sliding_seq_len - sliding_context_len
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
                query_lens.append(query_len)
                input_tokens.extend(tokens)
                input_positions.extend(list(range(context_len, seq_len)))
                lora_id = seq_group_metadata.lora_int_id

                if is_prompt:
                    assert len(seq_ids) == 1
                    num_prefills += 1
                    num_prefill_tokens += len(tokens)
                    decode_only = False
                    prefill_seq_lens.append(seq_len)
                else:
                    assert query_len == 1, (
                        "seq_len: {}, context_len: {}, query_len: {}".format(
                            seq_len, context_len, query_len))
                    num_decode_tokens += query_len
503
                    decode_seq_lens.append(sliding_seq_len)
504
505
506
507

                if lora_id > 0:
                    lora_requests.add(seq_group_metadata.lora_request)

508
                lora_index_mapping += [lora_id] * query_len
509
510
                lora_prompt_mapping.extend(
                    [lora_id] *
511
                    (query_len if seq_group_metadata.sampling_params
512
                     and seq_group_metadata.sampling_params.prompt_logprobs
513
                     is not None else 1))
514

515
                mm_data = seq_group_metadata.multi_modal_data
516
                if mm_data:
517
                    # Process multi-modal data
518
                    mm_kwargs = self.multi_modal_input_mapper(mm_data)
519
520
                    for k, v in mm_kwargs.items():
                        multi_modal_kwargs_list[k].append(v)
521

522
523
524
                is_profile_run = _is_block_tables_empty(
                    seq_group_metadata.block_tables)
                if is_profile_run:
525
526
527
528
529
530
                    # During memory profiling, the block tables are not
                    # initialized yet. In this case, we just use a dummy
                    # slot mapping.
                    # In embeddings, the block tables are {seq_id: None}.
                    slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                    continue
531

532
                # Compute the slot mapping.
533
534
                block_table = seq_group_metadata.block_tables[seq_id]

535
536
537
538
539
540
541
                # Mask the [0, start_idx) tokens of the prompt with
                # _PAD_SLOT_ID, where start_idx is max(0, seq_len -
                # sliding_window). For example, if the prompt len is 10,
                # sliding window is 8, and block size is 4, the first two
                # tokens are masked and the slot mapping will be
                # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
                start_idx = 0
542
                if self.sliding_window is not None:
543
                    if is_prompt:
544
545
                        assert self.scheduler_config.use_v2_block_manager \
                            or context_len == 0, (
546
                            "Prefix caching is currently not supported with "
547
                            "sliding window attention in V1 block manager")
548
549
550
551
552
553
554
555
556
557
558
559
560
561
                    # It is an optimization. When it is decoding, it is always
                    # 0. When prefill, we use it to not write slots to kv cache
                    # to save memory.
                    start_idx = max(0, query_len - self.sliding_window)

                for i in range(context_len, seq_len):
                    if i < start_idx:
                        slot_mapping.append(_PAD_SLOT_ID)
                        continue

                    block_number = block_table[i // self.block_size]
                    block_offset = i % self.block_size
                    slot = block_number * self.block_size + block_offset
                    slot_mapping.append(slot)
562

563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
                # Prepare input tensors for flashinfer
                if self.attn_backend.get_name() == "flashinfer":
                    seq_len = seq_data.get_len()
                    # Get the number of valid blocks based on sequence length.
                    # If seq_len = 16, block_size = 16,
                    # block_table_bound is 1 with 1 valid block.
                    # If seq_len = 15, block_size = 16,
                    # block_table_bound is 0 + 1 with 1 valid block.
                    block_table_bound = seq_len // self.block_size + 1 \
                                        if seq_len % self.block_size != 0 \
                                        else seq_len // self.block_size

                    paged_kv_indices.extend(block_table[:block_table_bound])
                    paged_kv_indptr.append(paged_kv_indptr[-1] +
                                           block_table_bound)

                    last_page_len = seq_len % self.block_size
                    if last_page_len == 0:
                        last_page_len = self.block_size
                    paged_kv_last_page_len.append(last_page_len)

584
585
586
587
        batch_size = len(input_tokens)
        max_query_len = max(query_lens)
        max_prefill_seq_len = max(prefill_seq_lens, default=0)
        max_decode_seq_len = max(decode_seq_lens, default=0)
588

589
        # If cuda graph can be used, pad tensors accordingly.
590
        # See `capture_model` API for more details.
591
592
593
594
595
        # vLLM uses cuda graph only for decoding requests.
        use_captured_graph = (
            decode_only and not self.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.max_seq_len_to_capture)
596
597
598
599
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            for _ in range(graph_batch_size - batch_size):
600
601
602
                input_tokens.append(0)
                input_positions.append(0)
                slot_mapping.append(_PAD_SLOT_ID)
603
                seq_lens.append(1)
604
                block_tables.append([])
605
                lora_index_mapping.append(0)
606
607
608
609
610
611

                if self.attn_backend.get_name() == "flashinfer":
                    last_paged_kv_indptr = paged_kv_indptr[-1]
                    paged_kv_indptr.append(last_paged_kv_indptr)
                    paged_kv_last_page_len.append(0)

612
            batch_size = graph_batch_size
613
            num_decode_tokens = batch_size
614
615
616
617
618
619
620
621

        if use_captured_graph:
            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = self.graph_block_tables[:batch_size]
            for i, block_table in enumerate(block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
622
            block_tables = torch.tensor(input_block_tables, device=self.device)
623
        else:
624
625
            max_block_table_len = max(
                len(block_table) for block_table in block_tables)
626
            block_tables = make_tensor_with_pad(
627
                block_tables,
628
                max_len=max_block_table_len,
629
630
                pad=0,
                dtype=torch.int,
631
                device=self.device,
632
            )
633
634
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

635
636
637
638
        context_lens_tensor = torch.tensor(context_lens,
                                           dtype=torch.int,
                                           device=self.device)

639
640
641
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)
642
643
644
645
646
647
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=self.device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=self.device)
648
649
650
651
652
653
654
655
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=self.device)

        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
656
657
658
659
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])
660
661
662
663
664
665
666
667
668
669

        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=self.device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=self.device)
        slot_mapping_tensor = torch.tensor(slot_mapping,
                                           dtype=torch.long,
                                           device=self.device)
670

671
        if self.attn_backend.get_name() == "flashinfer":
672
673
674
675
676
677
678
679
680
681
682
683
684
685
            if len(paged_kv_indptr) > 0:
                paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
                                                       device='cpu',
                                                       dtype=torch.int)
                paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
                                                      device='cpu',
                                                      dtype=torch.int)
                paged_kv_last_page_len_tensor = torch.tensor(
                    paged_kv_last_page_len, device='cpu', dtype=torch.int)
            else:
                paged_kv_indices_tensor = None
                paged_kv_indptr_tensor = None
                paged_kv_last_page_len_tensor = None

686
687
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)
688

689
            attn_metadata = self.attn_backend.make_metadata(
690
691
692
693
694
695
696
697
698
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                max_prefill_seq_len=max_prefill_seq_len,
                block_tables=block_tables,
                paged_kv_indptr=paged_kv_indptr_tensor,
                paged_kv_indices=paged_kv_indices_tensor,
                paged_kv_last_page_len=paged_kv_last_page_len_tensor,
699
700
701
702
703
                num_qo_heads=self.model_config.get_num_attention_heads(
                    self.parallel_config),
                num_kv_heads=self.model_config.get_num_kv_heads(
                    self.parallel_config),
                head_dim=self.model_config.get_head_size(),
704
                page_size=self.block_size,
705
                seq_start_loc=seq_start_loc,
706
707
708
709
                query_start_loc=query_start_loc,
                device=self.device,
                data_type=kv_cache_dtype,
                use_cuda_graph=use_captured_graph)
710

711
        else:
712
            attn_metadata = self.attn_backend.make_metadata(
713
714
715
716
717
                num_prefills=num_prefills,
                slot_mapping=slot_mapping_tensor,
                num_prefill_tokens=num_prefill_tokens,
                num_decode_tokens=num_decode_tokens,
                seq_lens=seq_lens,
718
                seq_lens_tensor=seq_lens_tensor,
719
720
721
722
723
724
                max_query_len=max_query_len,
                max_prefill_seq_len=max_prefill_seq_len,
                max_decode_seq_len=max_decode_seq_len,
                query_start_loc=query_start_loc,
                seq_start_loc=seq_start_loc,
                context_lens_tensor=context_lens_tensor,
725
726
727
                block_tables=block_tables,
                use_cuda_graph=use_captured_graph,
            )
728
729
730
731
732
733
734
735
736

        if self.lora_config:
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )
        else:
            lora_mapping = None

737
738
739
740
741
        multi_modal_kwargs = {
            k: torch.cat(v, dim=0).to(self.device)
            for k, v in multi_modal_kwargs_list.items()
        }

742
        return self._model_input_cls(
743
744
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
745
            attn_metadata=attn_metadata,
746
747
748
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
749
            lora_requests=lora_requests,
750
            multi_modal_kwargs=multi_modal_kwargs,
751
        )
752

753
754
755
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on a worst-case dummy batch to profile memory.

        The dummy batch holds ``max_num_seqs`` prompt-only sequence groups
        whose lengths sum to ``max_num_batched_tokens``. When LoRA is enabled,
        ``max_loras`` dummy adapters are registered and assigned round-robin
        to the sequences. When a vision-language config is present, the
        number of sequences is capped at
        ``max_num_batched_tokens / image_feature_size`` (but never below 1).
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register max_loras dummy adapters: this is the maximum number of
        # distinct LoRAs that can be active at once, and therefore the worst
        # case for memory consumption. They are assigned round-robin to the
        # dummy sequences below.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_local_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the vLLM
        # block manager. To exercise the worst case for GPU memory
        # consumption, the number of seqs (batch size) is chosen to maximize
        # the number of images processed.
        model_config = self.model_config
        vlm_config = self.vision_language_config

        if vlm_config:
            requested_max_num_seqs = max_num_seqs
            max_num_seqs = min(
                max_num_seqs,
                int(max_num_batched_tokens / vlm_config.image_feature_size))
            if max_num_seqs < 1:
                # Without this clamp the dummy batch would be empty and the
                # profiling run would measure nothing.
                logger.warning(
                    "Computed max_num_seqs (min(%s, %s // %s)) to be less "
                    "than 1. Setting it to the minimum value of 1.",
                    requested_max_num_seqs, max_num_batched_tokens,
                    vlm_config.image_feature_size)
                max_num_seqs = 1

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Total number of tokens across all dummy sequences; it sizes the
        # intermediate tensors fed to non-first pipeline stages.
        num_tokens = 0
        for group_id in range(max_num_seqs):
            # Split the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            num_tokens += seq_len

            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)
            assert len(seq_data.prompt_token_ids) == seq_len

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs; every layer gets a None KV
        # cache.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        model_input = self.prepare_model_input(seqs)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=num_tokens,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

    def remove_all_loras(self):
        """Ask the LoRA manager to remove every LoRA adapter.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_loras()

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the active LoRA requests and mapping to the LoRA manager.

        Args:
            lora_requests: LoRA requests to activate.
            lora_mapping: Mapping passed through unchanged to the manager.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_loras(lora_requests, lora_mapping)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Pass ``lora_request`` to the LoRA manager's ``add_lora``.

        Returns:
            The manager's return value.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to remove the adapter with id ``lora_id``.

        Returns:
            The manager's return value.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Ask the LoRA manager to pin the adapter with id ``lora_id``.

        Returns:
            The manager's return value.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the LoRA manager's ``list_loras``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_loras()

    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured per (virtual engine, batch size) pair, for each
        size in ``_BATCH_SIZES_TO_CAPTURE`` not exceeding the padded
        ``max_num_seqs``. All graphs share ``self.graph_memory_pool``, and the
        graphs of one virtual engine share one output buffer.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Prepare one output buffer per virtual engine, reused for all batch
        # sizes. Each buffer is filled by that engine's first (largest)
        # capture; smaller batch sizes then write into a slice of it.
        hidden_or_intermediate_states: List[Optional[Union[
            torch.Tensor, IntermediateTensors]]] = [
                None
            ] * self.parallel_config.pipeline_parallel_size

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        use_flashinfer = self.attn_backend.get_name() == "flashinfer"
        if use_flashinfer:
            # For flashinfer, different batch sizes will share the
            # same workspace buffer.
            decode_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
            indices_buffer = torch.empty(max_batch_size *
                                         self.cache_config.num_gpu_blocks,
                                         dtype=torch.int32,
                                         device=self.device)
            indptr_buffer = torch.empty(max_batch_size + 1,
                                        dtype=torch.int32,
                                        device=self.device)
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)
            # These do not depend on the batch size, so compute them once.
            num_qo_heads = self.model_config.get_num_attention_heads(
                self.parallel_config)
            num_kv_heads = self.model_config.get_num_kv_heads(
                self.parallel_config)
            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
            kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype)

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    if use_flashinfer:
                        # Slice into fresh locals. Re-binding the full buffers
                        # would leave them truncated to the smallest batch
                        # size when the next virtual engine restarts from the
                        # largest one.
                        indptr_buffer_bs = indptr_buffer[:batch_size + 1]
                        last_page_len_buffer_bs = last_page_len_buffer[
                            :batch_size]

                        decode_wrapper = \
                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                            decode_workspace_buffer, indptr_buffer_bs,
                            indices_buffer, last_page_len_buffer_bs, "NHD",
                            use_tensor_cores)

                        paged_kv_indptr_tensor_host = torch.arange(
                            0, batch_size + 1, dtype=torch.int32)
                        paged_kv_indices_tensor_host = torch.arange(
                            0, batch_size, dtype=torch.int32)
                        paged_kv_last_page_len_tensor_host = torch.full(
                            (batch_size, ), self.block_size, dtype=torch.int32)
                        query_start_loc_host = torch.arange(0,
                                                            batch_size + 1,
                                                            dtype=torch.int32)

                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            slot_mapping=slot_mapping[:batch_size],
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            max_prefill_seq_len=0,
                            block_tables=block_tables,
                            paged_kv_indptr=paged_kv_indptr_tensor_host,
                            paged_kv_indices=paged_kv_indices_tensor_host,
                            paged_kv_last_page_len=
                            paged_kv_last_page_len_tensor_host,
                            num_qo_heads=num_qo_heads,
                            num_kv_heads=num_kv_heads,
                            head_dim=self.model_config.get_head_size(),
                            page_size=self.block_size,
                            seq_start_loc=None,
                            query_start_loc=query_start_loc_host,
                            device=self.device,
                            data_type=kv_cache_dtype,
                            use_cuda_graph=True,
                            decode_wrapper=decode_wrapper,
                            prefill_wrapper=None)
                        attn_metadata.begin_forward()
                    else:
                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            slot_mapping=slot_mapping[:batch_size],
                            seq_lens=None,
                            seq_lens_tensor=seq_lens[:batch_size],
                            max_query_len=None,
                            max_prefill_seq_len=0,
                            max_decode_seq_len=self.max_seq_len_to_capture,
                            query_start_loc=None,
                            seq_start_loc=None,
                            context_lens_tensor=None,
                            block_tables=block_tables[:batch_size],
                            use_cuda_graph=True,
                        )

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            [0] * batch_size,
                            [0] * batch_size,
                        )
                        self.set_active_loras(set(), lora_mapping)

                    graph_runner = CUDAGraphRunner(
                        self.model, self.attn_backend.get_name())

                    if use_flashinfer:
                        graph_runner.flashinfer_indptr_buffer = \
                            indptr_buffer_bs
                        graph_runner.flashinfer_indices_buffer = \
                            indices_buffer
                        graph_runner.flashinfer_last_page_len_buffer = \
                            last_page_len_buffer_bs
                        graph_runner.flashinfer_decode_workspace_buffer = \
                            decode_workspace_buffer
                        graph_runner.flashinfer_decode_wrapper = \
                            decode_wrapper

                    output_buffer = hidden_or_intermediate_states[
                        virtual_engine]
                    captured_output = graph_runner.capture(
                        input_tokens[:batch_size],
                        input_positions[:batch_size],
                        output_buffer[:batch_size]  # type: ignore
                        if output_buffer is not None else None,
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        kv_caches[virtual_engine],
                        attn_metadata,
                        memory_pool=self.graph_memory_pool,
                        stream=graph_capture_context.stream,
                    )
                    if output_buffer is None:
                        # First (largest) capture for this virtual engine:
                        # keep its full-size output as the shared buffer.
                        hidden_or_intermediate_states[virtual_engine] = (
                            captured_output)
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)

    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        return self.model_config.get_vocab_size()


class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        using this runner's attention backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        # None when the batch is empty; otherwise taken from the first group
        # (the list is assumed to be ordered prefill -> decode).
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward step and, on the driver, sample the next token.

        Returns:
            The raw forward output on non-last pipeline stages, an empty list
            on non-driver workers, otherwise a one-element list holding the
            ``SamplerOutput``.

        Raises:
            ValueError: If ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                # Lazily create the eager-mode flashinfer wrappers on first
                # use; later steps reuse them.
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                # Each captured graph owns its decode wrapper. graph_runners
                # is indexed by virtual engine first, then by batch size.
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.virtual_engine][
                        batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
        )

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                # The graph ran on a padded batch; keep the first
                # len(indices) rows.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]


class CUDAGraphRunner:
    """Captures one model forward pass into a CUDA graph and replays it.

    `capture` runs the model a few times as warm-up, records one forward
    pass into a `torch.cuda.CUDAGraph`, and remembers the tensors the graph
    reads from (`input_buffers`) and writes into (`output_buffers`).
    `forward` copies fresh inputs into those static buffers, replays the
    graph, and returns the static output buffer(s).
    """

    def __init__(self, model: nn.Module, backend_name: str):
        self.model = model
        # Attention backend name. For "flashinfer" the seq_lens_tensor /
        # block_tables buffers are neither recorded nor refreshed.
        self.backend_name = backend_name

        # Static tensors recorded by `capture` and refreshed by `forward`.
        # NOTE(review): "kv_caches" holds a List[torch.Tensor], so this
        # annotation is looser than the real contents.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        # On the last pipeline-parallel rank: {"hidden_states": tensor}.
        # On other ranks: the IntermediateTensors the captured graph fills.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

        # FlashInfer-specific state. Nothing in this class assigns these
        # beyond None; presumably the caller populates them when
        # backend_name == "flashinfer" — TODO confirm.
        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        """The captured graph; only valid after `capture` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, capture one forward pass, and record buffers.

        Args:
            input_ids: Static token-id tensor the graph reads from.
            positions: Static position tensor the graph reads from.
            hidden_or_intermediate_states: Optional preallocated output
                buffer(s). When given, the captured graph copies the model
                output into them. When None, the model's own output
                becomes the output buffer.
            intermediate_inputs: Inputs from the previous pipeline stage,
                or None. Their tensors are recorded as input buffers.
            kv_caches: KV cache tensors. They are recorded but never
                copied on replay.
            attn_metadata: Attention metadata whose tensors become static
                input buffers.
            memory_pool: Pool handle passed to `torch.cuda.graph`.
            stream: Stream the capture runs on.
            **kwargs: Extra keyword arguments forwarded to the model.

        Returns:
            The output buffer(s) the captured graph writes into.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph. This keeps
        # one-time kernel launches used for initial benchmarking (e.g.
        # Triton autotune) out of the captured graph.
        # Note one iteration is not enough for torch.jit.script.
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # Copy into caller-provided buffers so the graph writes its
                # result to tensors the caller already holds.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input buffers. Every backend records these four entries.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
        }
        if self.backend_name != "flashinfer":
            # Non-FlashInfer backends also read decode seq lens and block
            # tables, so those must be static buffers too.
            self.input_buffers["seq_lens_tensor"] = (
                attn_metadata.decode_metadata.seq_lens_tensor)
            self.input_buffers["block_tables"] = (
                attn_metadata.decode_metadata.block_tables)
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)

        # Save the output buffers.
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy inputs into the static buffers, replay, and return outputs.

        Args:
            input_ids: New token ids, copied into the captured buffer.
            positions: New positions, copied into the captured buffer.
            kv_caches: Ignored; the captured KV cache tensors are reused.
            attn_metadata: Source of slot_mapping and, except for
                FlashInfer, decode seq_lens_tensor / block_tables.
            intermediate_tensors: Optional previous-stage tensors, copied
                key by key into the matching input buffers.
            **kwargs: Accepted for call compatibility but not used here.

        Returns:
            The hidden-states tensor on the last pipeline-parallel rank.
            Otherwise, the IntermediateTensors output buffer.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)

        # Run the graph.
        self.graph.replay()

        # Return the output buffer(s) the graph just wrote into.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1368

1369
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the padded batch size for ``batch_size``.

    Sizes of 2 or less come back unchanged, 3 and 4 become 4, and anything
    larger is rounded up to the next multiple of ``_BATCH_SIZE_ALIGNMENT``.
    The padded sequence is 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT, ...
    """
    if batch_size > 4:
        align = _BATCH_SIZE_ALIGNMENT
        # Integer ceil-division by the alignment, then scale back up.
        num_chunks = (batch_size + align - 1) // align
        return num_chunks * align
    return 4 if batch_size > 2 else batch_size
1382
1383


1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """Report whether ``block_tables`` carries no block-table entries.

    Returns True for ``None`` and for a dict whose values are all ``None``
    (so an empty dict also counts as empty). Any other value, including
    objects that are not dicts, yields False.
    """
    if block_tables is None:
        return True
    if not isinstance(block_tables, dict):
        return False
    # Empty unless at least one value is something other than None.
    return not any(entry is not None for entry in block_tables.values())