model_runner.py 67.1 KB
Newer Older
1
import dataclasses
2
import gc
3
import time
4
import warnings
5
import weakref
6
from dataclasses import dataclass, field
7
8
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                    Tuple, Type, TypeVar, Union)
9

10
import numpy as np
11
import torch
12
import torch.distributed
13
import torch.nn as nn
14

15
16
17
18
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
19
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
20
21
22
23
24
25
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

26
from vllm.attention import AttentionMetadata, get_attn_backend
27
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
28
                         ModelConfig, MultiModalConfig, ParallelConfig,
29
                         PromptAdapterConfig, SchedulerConfig)
30
from vllm.distributed import get_pp_group
31
from vllm.distributed.parallel_state import graph_capture
32
from vllm.inputs import INPUT_REGISTRY
33
from vllm.logger import init_logger
34
35
36
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
37
from vllm.model_executor import SamplingMetadata
38
from vllm.model_executor.model_loader import get_model
39
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
40
41
from vllm.model_executor.models.interfaces import (supports_lora,
                                                   supports_vision)
42
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
43
44
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
45
46
47
48
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
49
from vllm.sampling_params import SamplingParams
50
51
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
52
53
from vllm.utils import (CudaMemoryProfiler, flatten_2d_lists,
                        get_kv_cache_torch_dtype, is_hip,
54
                        is_pin_memory_available)
55
from vllm.worker.model_runner_base import (
56
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
57
58
59
60
61
62
63
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
64
65
66
67

logger = init_logger(__name__)

_PAD_SLOT_ID = -1
68
LORA_WARMUP_RANK = 8
69
70
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
71
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
72
73
74
# Batch sizes for which CUDA graphs are captured: 1, 2 and 4, followed by
# _BATCH_SIZE_ALIGNMENT * i for i = 1..32. The last element is therefore the
# largest captured batch size; ModelInputForGPUBuilder.build compares the
# runtime batch size against it when deciding whether a captured graph can be
# used.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
75
_NUM_WARMUP_ITERS = 2
76

77
78
79
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


80
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    Inputs needed for the base model forward pass.

    Metadata for possible additional steps (e.g., sampling) is not part of
    this class. Model runners that run such steps should subclass this class
    and add the extra fields there.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields of this input as a plain dict.

        ``seq_lens``, ``query_lens`` and ``attn_metadata`` are not copied
        into the dict directly; the attention metadata is handed to
        ``_add_attn_metadata_broadcastable_dict`` together with the dict.
        """
        # Key order mirrors the order in which the entries are inserted.
        broadcast_fields = (
            "input_tokens",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "prompt_adapter_mapping",
            "prompt_adapter_requests",
            "virtual_engine",
            "request_ids_to_seq_ids",
            "finished_requests_ids",
        )
        tensor_dict: Dict[str, Any] = {
            name: getattr(self, name)
            for name in broadcast_fields
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Build an instance from a dict of keyword arguments.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict``; otherwise it is used as is.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        return cls(**_init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict))


130
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Model input that additionally carries sampling metadata.

    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. It is not broadcast because only the
    # driver worker uses it.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the base-class dict extended with the sampling metadata.

        The base class contributes the plain fields and the attention
        metadata; the sampling metadata is then added through
        ``_add_sampling_metadata_broadcastable_dict``.
        """
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Build an instance from a dict of keyword arguments.

        The dict always goes through
        ``_init_sampling_metadata_from_tensor_dict`` first, and through
        ``_init_attn_metadata_from_tensor_dict`` afterwards when an
        ``attn_backend`` is supplied.
        """
        init_kwargs = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is None:
            return cls(**init_kwargs)
        return cls(**_init_attn_metadata_from_tensor_dict(
            attn_backend, init_kwargs))


171
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
    """Build ModelInputForGPU from SequenceGroupMetadata."""

    @dataclass
    class InterDataForSeqGroup:
        """Per-sequence-group scratch data collected while building inputs.

        The per-sequence list fields are re-created in ``__post_init__`` with
        one slot per entry of ``seq_ids``.
        """
        # Values copied from the sequence group metadata.
        request_id: str
        seq_ids: List[int]
        is_prompt: bool
        block_tables: Optional[Dict[int, List[int]]]
        computed_block_nums: List[int]
        # Overwritten with len(seq_ids) in __post_init__.
        n_seqs: int = 0

        # Token ids and their positions, one inner list per sequence.
        input_tokens: List[List[int]] = field(default_factory=list)
        input_positions: List[List[int]] = field(default_factory=list)

        # Sequence length; may be capped to the sliding window.
        seq_lens: List[int] = field(default_factory=list)
        # Sequence length before the sliding window is applied. This is used
        # to compute the slot mapping.
        orig_seq_lens: List[int] = field(default_factory=list)
        # Query length.
        query_lens: List[int] = field(default_factory=list)
        # Number of tokens that are already computed.
        context_lens: List[int] = field(default_factory=list)
        # Current sliding window block.
        curr_sliding_window_blocks: List[int] = field(default_factory=list)

        # LoRA inputs.
        lora_index_mapping: List[List[int]] = field(default_factory=list)
        lora_prompt_mapping: List[List[int]] = field(default_factory=list)
        lora_requests: Set[LoRARequest] = field(default_factory=set)

        # Prompt adapter inputs.
        prompt_adapter_index_mapping: List[int] = field(default_factory=list)
        prompt_adapter_prompt_mapping: List[int] = field(default_factory=list)
        prompt_adapter_request: Optional[PromptAdapterRequest] = None

        # Multi-modal inputs.
        multi_modal_inputs: Optional[MultiModalInputs] = None

        # Whether the prefix cache is hit (prefill only).
        prefix_cache_hit: bool = False

        def __post_init__(self):
            """Size every per-sequence container to the number of seq ids."""
            num_seqs = len(self.seq_ids)
            self.n_seqs = num_seqs

            def empty_lists() -> List[List[int]]:
                # A distinct inner list per sequence.
                return [[] for _ in range(num_seqs)]

            def zeros() -> List[int]:
                return [0] * num_seqs

            self.input_tokens = empty_lists()
            self.input_positions = empty_lists()
            self.seq_lens = zeros()
            self.orig_seq_lens = zeros()
            self.query_lens = zeros()
            self.context_lens = zeros()
            self.curr_sliding_window_blocks = zeros()

            self.lora_index_mapping = empty_lists()
            self.lora_prompt_mapping = empty_lists()
230
231
232
233
234

    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        """Set up the builder from the given model runner.

        Args:
            runner: Model runner whose configuration (attention backend,
                scheduler config, sliding window, block size, LoRA / prompt
                adapter configs, multi-modal mapper) is copied onto the
                builder.
            finished_requests_ids: Stored as is and forwarded to the model
                input created by ``build``.
        """
        super().__init__()

        # Functions run for every sequence of a sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_compute_fns = [
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Functions run once per sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

        self.runner = runner
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.scheduler_config = runner.scheduler_config
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.enable_lora = runner.lora_config is not None
        self.enable_prompt_adapter = (runner.prompt_adapter_config
                                      is not None)
        self.multi_modal_input_mapper = runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        # Flipped to False by add_seq_group once a prompt group is added.
        self.decode_only = True

        # Intermediate data (data in CPU before going to GPU), one entry
        # per added sequence group.
        self.inter_data_list: List[
            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

        # Attention metadata inputs. The metadata builder receives a weak
        # proxy of this object (presumably to avoid a reference cycle); it
        # is created only after the attributes above have been assigned.
        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
            weakref.proxy(self))

        # Engine/Model configurations.
        sched_config = self.scheduler_config
        self.chunked_prefill_enabled = (
            sched_config is not None
            and sched_config.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Number of blocks needed to hold the window, rounded up.
            num_window_blocks = (self.sliding_window + self.block_size -
                                 1) // self.block_size
            self.sliding_window_blocks = num_window_blocks
            self.block_aligned_sliding_window = (num_window_blocks *
                                                 self.block_size)

282
283
284
285
286
287
288
    def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                      seq_group_metadata: SequenceGroupMetadata):
        """Fill lengths, tokens and positions for one sequence.

        Writes the sequence length, original sequence length, context
        length, query length, input tokens and input positions of the
        sequence at ``seq_idx`` into ``inter_data``.
        """
        seq_id = inter_data.seq_ids[seq_idx]
        seq_data = seq_group_metadata.seq_data[seq_id]
        is_prompt = inter_data.is_prompt

        # Context length is the number of tokens that are already computed;
        # sequence length is the total number of tokens, capped by the
        # chunk size scheduled for this step.
        total_len = seq_data.get_len()
        if is_prompt:
            context_len = seq_data.get_num_computed_tokens()
        else:
            # get_num_computed_tokens is incorrect for spec decoding.
            # So, we should have a special logic here.
            # TODO(sang): Fix it.
            context_len = total_len - 1
        seq_len = min(total_len,
                      context_len + seq_group_metadata.token_chunk_size)

        if is_prompt:
            tokens = seq_data.get_token_ids()[context_len:seq_len]
            query_len = seq_len - context_len
        else:
            # Optimization: get_token_ids requires the entire copy of
            # tokens, while only the last one is needed here.
            tokens = [seq_data.get_last_token_id()]
            query_len = 1

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.context_lens[seq_idx] = context_len
        inter_data.input_tokens[seq_idx] = tokens
        inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
        inter_data.query_lens[seq_idx] = query_len

    def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int,
            seq_group_metadata: SequenceGroupMetadata):
        """Skip already-computed blocks when the prefix cache is hit.

        A hit requires a non-empty ``computed_block_nums``, no sliding
        window, and a prompt sequence group. On a hit, the input tokens,
        positions, context length and query length of the sequence are
        updated so that only the remaining tokens are computed.

        Raises:
            RuntimeError: If the prefix cache is hit while chunked prefill
                is enabled.
        """
        cached_blocks = inter_data.computed_block_nums

        # Note that prefix caching does not support sliding window.
        hit = (cached_blocks is not None and len(cached_blocks) > 0
               and self.sliding_window is None and inter_data.is_prompt)
        inter_data.prefix_cache_hit = hit
        if not hit:
            return

        if self.chunked_prefill_enabled:
            raise RuntimeError(
                "chunked prefill cannot be used with prefix caching now.")

        # Advance the context length past the cached blocks and trim the
        # tokens, positions and query length accordingly.
        assert cached_blocks is not None
        cached_len = len(cached_blocks) * self.block_size
        remaining_tokens = inter_data.input_tokens[seq_idx][cached_len:]
        remaining_positions = inter_data.input_positions[seq_idx][cached_len:]
        inter_data.input_tokens[seq_idx] = remaining_tokens
        inter_data.input_positions[seq_idx] = remaining_positions
        inter_data.context_lens[seq_idx] = cached_len
        inter_data.query_lens[seq_idx] = (inter_data.seq_lens[seq_idx] -
                                          cached_len)

    def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                    seq_idx: int,
                                    seq_group_metadata: SequenceGroupMetadata):
        """Apply the sliding window to a decoding sequence.

        Updates ``seq_lens`` and ``curr_sliding_window_blocks`` for the
        sequence at ``seq_idx``. For prompt groups, or when no sliding
        window is configured, the sequence length is left unchanged and the
        sliding window block count is set to 0.
        """
        seq_len = inter_data.seq_lens[seq_idx]
        window_blocks = 0
        windowed_len = seq_len

        if self.sliding_window is not None and not inter_data.is_prompt:
            # TODO(sang): This is a hack to make sliding window work with
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle sliding window attn.
            window_blocks = self.sliding_window_blocks
            if self.scheduler_config.use_v2_block_manager:
                # Number of elements in the last block.
                tail_len = seq_len % self.block_size
                windowed_len = min(
                    seq_len, self.block_aligned_sliding_window + tail_len)
                if tail_len > 0:
                    window_blocks += 1
            else:
                windowed_len = min(seq_len, self.sliding_window)

        inter_data.curr_sliding_window_blocks[seq_idx] = window_blocks
        inter_data.seq_lens[seq_idx] = windowed_len

    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                            seq_idx: int,
                            seq_group_metadata: SequenceGroupMetadata):
        """If LoRA is enabled, compute LoRA index and prompt mapping.

        Stores, for the sequence at ``seq_idx``:
          * an index mapping with the LoRA id repeated once per query token;
          * a prompt mapping with the LoRA id repeated once per query token
            when prompt logprobs are requested (``prompt_logprobs`` is not
            None), and once otherwise.

        The LoRA request is recorded in ``inter_data.lora_requests`` when
        the LoRA id is positive. Does nothing when LoRA is disabled.
        """
        if not self.enable_lora:
            return

        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            inter_data.lora_requests.add(seq_group_metadata.lora_request)

        query_len = inter_data.query_lens[seq_idx]
        sampling_params = seq_group_metadata.sampling_params
        wants_prompt_logprobs = (sampling_params
                                 and sampling_params.prompt_logprobs
                                 is not None)

        # The mapping lists are pre-sized with one slot per sequence, so the
        # slot of this sequence is filled in place (like the other
        # per-sequence fields) instead of appending extra entries after the
        # empty placeholders. The flattened result is the same.
        inter_data.lora_index_mapping[seq_idx] = [lora_id] * query_len
        inter_data.lora_prompt_mapping[seq_idx] = [lora_id] * (
            query_len if wants_prompt_logprobs else 1)

    def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup,
            seq_group_metadata: SequenceGroupMetadata):
        """If prompt adapter is enabled, compute index and prompt mapping.

        Nothing is done when prompt adapters are disabled, when the group
        has no positive prompt adapter id, or when the group is not a
        prompt group. Otherwise the prompt adapter request, the index
        mapping and the prompt mapping are stored on ``inter_data``.
        """
        # Note that when is_prompt=True, we expect only one sequence
        # in the group.
        if not self.enable_prompt_adapter:
            return

        prompt_adapter_id = seq_group_metadata.prompt_adapter_id
        if prompt_adapter_id <= 0 or not inter_data.is_prompt:
            return

        # We expect only one sequence in the group when is_prompt=True.
        assert inter_data.n_seqs == 1
        query_len = inter_data.query_lens[0]
        inter_data.prompt_adapter_request = (
            seq_group_metadata.prompt_adapter_request)

        # The first `num_tokens` query positions map to the adapter id and
        # the remaining ones to 0.
        num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
        inter_data.prompt_adapter_index_mapping = [
            prompt_adapter_id
        ] * num_tokens + [0] * (query_len - num_tokens)

        # Test `prompt_logprobs` against None rather than by truthiness, so
        # that a requested value of 0 is treated as "prompt logprobs
        # requested", matching the check used for the LoRA prompt mapping.
        sampling_params = seq_group_metadata.sampling_params
        wants_prompt_logprobs = (sampling_params
                                 and sampling_params.prompt_logprobs
                                 is not None)
        inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
            query_len if wants_prompt_logprobs else 1)

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """Map the group's multi-modal data, if any, into model inputs.

        When the sequence group carries multi-modal data, the result of
        ``self.multi_modal_input_mapper`` on it is stored in
        ``inter_data.multi_modal_inputs``; otherwise nothing is changed.
        """
        mm_data = seq_group_metadata.multi_modal_data
        if mm_data:
            inter_data.multi_modal_inputs = self.multi_modal_input_mapper(
                mm_data)
435
436

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Add a sequence group to the builder.

        Creates the intermediate data for the group, appends it to
        ``self.inter_data_list`` and then runs the per-sequence compute
        functions for every sequence, followed by the per-group compute
        functions.
        """
        is_prompt = seq_group_metadata.is_prompt
        seq_ids = list(seq_group_metadata.seq_data.keys())
        num_seqs = len(seq_ids)

        if is_prompt:
            # A prompt group holds exactly one sequence, and its presence
            # means the batch is no longer decode-only.
            assert num_seqs == 1
            self.decode_only = False

        group_data = self.InterDataForSeqGroup(
            request_id=seq_group_metadata.request_id,
            seq_ids=seq_ids,
            is_prompt=is_prompt,
            block_tables=seq_group_metadata.block_tables,
            computed_block_nums=seq_group_metadata.computed_block_nums)
        self.inter_data_list.append(group_data)

        for seq_idx in range(num_seqs):
            for compute_fn in self.per_seq_compute_fns:
                compute_fn(group_data, seq_idx, seq_group_metadata)
        for group_compute_fn in self.per_seq_group_compute_fns:
            group_compute_fn(group_data, seq_group_metadata)
459
460

    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and
        create on-device tensors.

        Returns an empty model input (``self.model_input_cls()``) when no
        input tokens were collected; otherwise a model input holding the
        token/position tensors, the attention metadata, and the LoRA,
        prompt adapter and multi-modal data.
        """
        groups = self.inter_data_list

        # Combine and flatten intermediate data.
        input_tokens = flatten_2d_lists(
            [flatten_2d_lists(group.input_tokens) for group in groups])
        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        input_positions = flatten_2d_lists(
            [flatten_2d_lists(group.input_positions) for group in groups])

        seq_lens = []
        max_decode_seq_len = 0
        for group in groups:
            seq_lens.extend(group.seq_lens)
            if not group.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(group.seq_lens))
        query_lens = flatten_2d_lists([group.query_lens for group in groups])

        # Mapping from request IDs to sequence IDs. Used for Jamba models,
        # which manage the cache by themselves.
        request_ids_to_seq_ids = {
            group.request_id: group.seq_ids
            for group in groups
        }

        batch_size = len(input_tokens)
        use_captured_graph = (
            self.decode_only and not self.runner.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and max_decode_seq_len <= self.runner.max_seq_len_to_capture)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        cuda_graph_pad_size = -1
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            batch_size = graph_batch_size

        # When no graph is used the pad size is -1, which makes this an
        # empty list, so the padding steps below add nothing.
        zero_padding = [0] * cuda_graph_pad_size

        # Tokens and positions.
        input_tokens.extend(zero_padding)
        input_positions.extend(zero_padding)
        device = self.runner.device
        input_tokens_tensor = torch.tensor(input_tokens,
                                           dtype=torch.long,
                                           device=device)
        input_positions_tensor = torch.tensor(input_positions,
                                              dtype=torch.long,
                                              device=device)

        # Sequence and query lengths. Padded entries get a length of 1.
        seq_lens.extend([1] * cuda_graph_pad_size)

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                request
                for group in groups
                for request in group.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(group.lora_index_mapping) for group in groups
            ])
            lora_index_mapping.extend(zero_padding)
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(group.lora_prompt_mapping)
                for group in groups
            ])
            lora_mapping = LoRAMapping(
                lora_index_mapping,
                lora_prompt_mapping,
            )

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = {
                group.prompt_adapter_request
                for group in groups
                if group.prompt_adapter_request is not None
            }
            prompt_adapter_index_mapping = flatten_2d_lists(
                [group.prompt_adapter_index_mapping for group in groups])
            prompt_adapter_index_mapping.extend(zero_padding)
            prompt_adapter_prompt_mapping = flatten_2d_lists(
                [group.prompt_adapter_prompt_mapping for group in groups])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data: only groups that actually carry inputs are
        # batched.
        multi_modal_inputs_list = [
            group.multi_modal_inputs for group in groups
            if group.multi_modal_inputs is not None
        ]
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
                                                    device=device)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
588
589


590
591
592
593
594
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
595
596
597
598
599
600

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Record the configuration objects and build model-independent state.

        No weights are loaded here: ``self.model`` is only declared, and the
        LoRA / prompt-adapter managers start out as ``None``.

        Args:
            model_config: Model settings; also queried here for the sliding
                window, ``max_seq_len_to_capture``, head counts and dtype.
            parallel_config: Parallelism settings; its
                ``pipeline_parallel_size`` fixes how many CUDA graph runner
                tables are created.
            scheduler_config: Stored as-is.
            device_config: Supplies ``self.device``.
            cache_config: Supplies ``block_size`` and ``cpu_offload_gb``.
            load_config: Stored as-is.
            lora_config: Stored as-is; may be ``None``.
            kv_cache_dtype: Stored and forwarded to ``get_attn_backend``.
            is_driver_worker: Stored as-is.
            prompt_adapter_config: Stored as-is; may be ``None``.
            multimodal_config: Stored as-is; may be ``None``.
            return_hidden_states: Stored as-is.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.prompt_adapter_config = prompt_adapter_config
        self.multimodal_config = multimodal_config
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture

        # One {batch_size: runner} table per pipeline-parallel stage.
        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
            parallel_config)

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        # NOTE: get_max_block_per_batch() reads self.block_size and
        # self.max_seq_len_to_capture, so both must already be set above.
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

        # No attention backend is selected when the model reports a falsy
        # (e.g. zero) number of attention heads.
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        # Annotated Optional to match the None default and the
        # `if not self.prompt_adapter_manager` guards used by the
        # prompt-adapter methods of this class.
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

        # cpu_offload_gb is expressed in GiB; convert to bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

679
    def load_model(self) -> None:
        """Load the model and wrap it with the configured adapter managers.

        Steps, in order:

        1. Build the model via ``get_model`` inside a ``CudaMemoryProfiler``
           and record the consumed memory in ``self.model_memory_usage``.
        2. If ``lora_config`` is set, create the LoRA manager and replace
           ``self.model`` with what ``create_lora_manager`` returns.
        3. If ``prompt_adapter_config`` is set, do the same with the
           prompt-adapter manager.
        4. For an ``"fp8"`` KV cache on ROCm (``is_hip()``), load KV-cache
           scaling factors from ``quantization_param_path`` when given.

        Raises:
            AssertionError: If LoRA is configured but the model does not
                support LoRA, or the model is a vision model.
            RuntimeError: If FP8 scaling factors were provided but the model
                has no callable ``load_kv_cache_scales``.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with CudaMemoryProfiler() as m:
            self.model = get_model(model_config=self.model_config,
                                   device_config=self.device_config,
                                   load_config=self.load_config,
                                   lora_config=self.lora_config,
                                   multimodal_config=self.multimodal_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config,
                                   cache_config=self.cache_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(self.model), "Model does not support LoRA"
            assert not supports_vision(
                self.model
            ), "To be tested: vision language model with LoRA settings."

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.prompt_adapter_config:
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.device,
                self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            scales_path = self.model_config.quantization_param_path
            if scales_path is None:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")
            elif callable(getattr(self.model, "load_kv_cache_scales", None)):
                warnings.warn(
                    "Loading kv cache scaling factor from JSON is "
                    "deprecated and will be removed. Please include "
                    "kv cache scaling factors in the model checkpoint.",
                    FutureWarning,
                    stacklevel=2)
                self.model.load_kv_cache_scales(scales_path)
                logger.info("Loaded KV cache scaling factors from %s",
                            scales_path)
            else:
                # Exceptions do not %-format their arguments the way logging
                # calls do, so the message is interpolated explicitly.
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    f"model {self.model.__class__} does not support loading "
                    "scaling factors.")
749

750
751
752
753
754
755
756
757
758
759
760
761
762
763
    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save ``self.model`` through ``ShardedStateLoader.save_model``.

        Args:
            path: Forwarded positionally as the save location.
            pattern: Forwarded unchanged as ``pattern``.
            max_size: Forwarded unchanged as ``max_size``.
        """
        # The loader is imported at call time rather than at module import.
        from vllm.model_executor.model_loader.loader import ShardedStateLoader

        save_options = {"pattern": pattern, "max_size": max_size}
        ShardedStateLoader.save_model(self.model, path, **save_options)

764
765
766
767
768
769
770
771
772
773
    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Save ``self.model`` through ``TensorizerLoader.save_model``.

        Args:
            tensorizer_config: Forwarded unchanged as ``tensorizer_config``.
        """
        # The loader is imported at call time rather than at module import.
        from vllm.model_executor.model_loader.loader import TensorizerLoader

        TensorizerLoader.save_model(self.model,
                                    tensorizer_config=tensorizer_config)

774
775
    def get_max_block_per_batch(self) -> int:
        """Return ``max_seq_len_to_capture / block_size``, rounded up."""
        # Adding (block_size - 1) before the floor division rounds up.
        padded_len = self.max_seq_len_to_capture + self.block_size - 1
        return padded_len // self.block_size
777

778
    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForGPU:
        """Build the model input for the base forward pass.

        Only the metadata needed by the model's forward pass is prepared;
        metadata for later steps such as sampling is not.

        The caller must pass seq_group_metadata_list sorted prefill -> decode.
        The resulting tensors keep that order, e.g.

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        Inputs are padded automatically when a CUDA graph is required.
        """
        # The builder receives a weak proxy to this runner, not a strong
        # reference.
        runner_proxy = weakref.proxy(self)
        input_builder = ModelInputForGPUBuilder(runner_proxy,
                                                finished_requests_ids)
        for group_metadata in seq_group_metadata_list:
            input_builder.add_seq_group(group_metadata)
        return input_builder.build()  # type: ignore
802

803
804
805
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Execute one forward pass on dummy prompts to exercise peak memory.

        Dummy prompt sequences totalling ``max_num_batched_tokens`` tokens
        are built (with dummy LoRAs and dummy multi-modal data where
        applicable), run through ``execute_model`` with ``None`` KV caches,
        and the device is synchronized afterwards.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # max_loras is the maximum number of requests that can carry distinct
        # LoRAs, and therefore the worst case for memory consumption. Register
        # that many dummy LoRAs and spread them round-robin over the
        # sequences.
        warmup_loras: List[LoRARequest] = []
        loras_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for lora_id in range(1, self.lora_config.max_loras + 1):
                    warmup_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(warmup_request,
                                                     rank=LORA_WARMUP_RANK)
                    warmup_loras.append(warmup_request)
                loras_per_seq = [
                    warmup_loras[seq_idx % len(warmup_loras)]
                    for seq_idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []

        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for the
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        model_config = self.model_config
        if supports_vision(self.model):
            max_mm_tokens = MULTIMODAL_REGISTRY \
                .get_max_multimodal_tokens(model_config)
            requested_num_seqs = max_num_seqs
            max_num_seqs = min(requested_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({requested_num_seqs}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Split the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = max_num_batched_tokens // max_num_seqs
            if group_id < max_num_batched_tokens % max_num_seqs:
                seq_len += 1
            batch_size += seq_len

            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            group_lora = loras_per_seq[group_id] if loras_per_seq else None
            seqs.append(
                SequenceGroupMetadata(
                    request_id=str(group_id),
                    is_prompt=True,
                    seq_data={group_id: seq_data},
                    sampling_params=sampling_params,
                    block_tables=None,
                    lora_request=group_lora,
                    multi_modal_data=dummy_multi_modal_data,
                ))

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)

        # Only the first pipeline rank runs without intermediate tensors.
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

900
    def remove_all_loras(self):
        """Call ``remove_all_adapters()`` on the LoRA manager.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        manager.remove_all_adapters()
904

905
    def set_active_loras(self, lora_requests: Set[LoRARequest],
906
907
908
                         lora_mapping: LoRAMapping) -> None:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
909
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
910
911
912
913

    def add_lora(self, lora_request: LoRARequest) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
914
        return self.lora_manager.add_adapter(lora_request)
915
916
917
918

    def remove_lora(self, lora_id: int) -> bool:
        """Return the LoRA manager's ``remove_adapter(lora_id)`` result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.remove_adapter(lora_id)
920
921
922
923

    def pin_lora(self, lora_id: int) -> bool:
        """Return the LoRA manager's ``pin_adapter(lora_id)`` result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.pin_adapter(lora_id)
925
926
927
928

    def list_loras(self) -> Set[int]:
        """Return the LoRA manager's ``list_adapters()`` result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        manager = self.lora_manager
        if not manager:
            raise RuntimeError("LoRA is not enabled.")
        return manager.list_adapters()

    def remove_all_prompt_adapters(self):
        """Call ``remove_all_adapters()`` on the prompt-adapter manager.

        Raises:
            RuntimeError: If no prompt-adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        manager.remove_all_adapters()

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        self.prompt_adapter_manager.set_active_adapters(
            prompt_adapter_requests, prompt_adapter_mapping)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        if not self.prompt_adapter_manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return self.prompt_adapter_manager.add_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Return the prompt-adapter manager's ``remove_adapter`` result.

        Raises:
            RuntimeError: If no prompt-adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.remove_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Return the prompt-adapter manager's ``pin_adapter`` result.

        Raises:
            RuntimeError: If no prompt-adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.pin_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt-adapter manager's ``list_adapters()`` result.

        Raises:
            RuntimeError: If no prompt-adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.list_adapters()
964

965
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured per (pipeline stage, batch size)
        pair and stored in ``self.graph_runners[virtual_engine][batch_size]``.

        Args:
            kv_caches: Indexed by virtual engine; ``kv_caches[virtual_engine]``
                is handed to each capture for that stage.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        capture_started = time.perf_counter()

        # Dummy inputs sized for the largest batch; every smaller batch size
        # captures against a leading slice of these same tensors.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()

        # Only non-first pipeline ranks take intermediate tensors as input.
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Per-stage output buffers offered to each capture.
        # NOTE(review): nothing in this method assigns into this list, so
        # every capture below receives None for it -- confirm whether the
        # result of a capture was meant to be stored here.
        pp_size = self.parallel_config.pipeline_parallel_size
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = (
            [None] * pp_size)

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        backend_name = self.attn_backend.get_name()
        is_flashinfer = backend_name == "flashinfer"
        if is_flashinfer:
            # For flashinfer, different batch sizes will share the
            # same workspace buffer.
            decode_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
            indices_buffer = torch.empty(max_batch_size *
                                         self.cache_config.num_gpu_blocks,
                                         dtype=torch.int32,
                                         device=self.device)
            indptr_buffer = torch.empty(max_batch_size + 1,
                                        dtype=torch.int32,
                                        device=self.device)
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(pp_size):
                for batch_size in reversed(batch_size_capture_list):
                    if is_flashinfer:
                        _indptr_buffer = indptr_buffer[:batch_size + 1]
                        _last_page_len_buffer = (
                            last_page_len_buffer[:batch_size])

                        num_qo_heads = (
                            self.model_config.get_num_attention_heads(
                                self.parallel_config))
                        num_kv_heads = self.model_config.get_num_kv_heads(
                            self.parallel_config)
                        # Tensor cores are requested once there are at least
                        # four query heads per KV head.
                        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
                        decode_wrapper = \
                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                            decode_workspace_buffer, _indptr_buffer,
                            indices_buffer, _last_page_len_buffer, "NHD",
                            use_tensor_cores)
                        kv_cache_dtype = get_kv_cache_torch_dtype(
                            self.kv_cache_dtype, self.model_config.dtype)

                        # Dummy paging metadata, created without a device
                        # argument.
                        paged_kv_indptr_tensor_host = torch.arange(
                            0, batch_size + 1, dtype=torch.int32)
                        paged_kv_indices_tensor_host = torch.arange(
                            0, batch_size, dtype=torch.int32)
                        paged_kv_last_page_len_tensor_host = torch.full(
                            (batch_size, ), self.block_size, dtype=torch.int32)
                        query_start_loc_host = torch.arange(0,
                                                            batch_size + 1,
                                                            dtype=torch.int32)

                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            slot_mapping=slot_mapping[:batch_size],
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            max_prefill_seq_len=0,
                            block_tables=block_tables,
                            paged_kv_indptr=paged_kv_indptr_tensor_host,
                            paged_kv_indices=paged_kv_indices_tensor_host,
                            paged_kv_last_page_len=
                            paged_kv_last_page_len_tensor_host,
                            num_qo_heads=num_qo_heads,
                            num_kv_heads=num_kv_heads,
                            head_dim=self.model_config.get_head_size(),
                            page_size=self.block_size,
                            seq_start_loc=None,
                            query_start_loc=query_start_loc_host,
                            device=self.device,
                            data_type=kv_cache_dtype,
                            use_cuda_graph=True,
                            decode_wrapper=decode_wrapper,
                            prefill_wrapper=None)
                        attn_metadata.begin_forward()
                    else:
                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            slot_mapping=slot_mapping[:batch_size],
                            seq_lens=None,
                            seq_lens_tensor=seq_lens[:batch_size],
                            max_query_len=None,
                            max_prefill_seq_len=0,
                            max_decode_seq_len=self.max_seq_len_to_capture,
                            query_start_loc=None,
                            seq_start_loc=None,
                            context_lens_tensor=None,
                            block_tables=block_tables[:batch_size],
                            use_cuda_graph=True,
                        )

                    # Capture with an empty adapter set and placeholder
                    # per-token mappings.
                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            [0] * batch_size,
                            [0] * batch_size,
                        )
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)

                    graph_runner = CUDAGraphRunner(self.model, backend_name)

                    if is_flashinfer:
                        # Attach the flashinfer buffers and wrapper used
                        # for this capture to the runner.
                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
                        graph_runner.flashinfer_indices_buffer = indices_buffer
                        graph_runner.flashinfer_last_page_len_buffer = \
                            _last_page_len_buffer
                        graph_runner.flashinfer_decode_workspace_buffer = \
                            decode_workspace_buffer
                        graph_runner.flashinfer_decode_wrapper = \
                            decode_wrapper

                    stage_output = hidden_or_intermediate_states[
                        virtual_engine]
                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[:batch_size],
                        "hidden_or_intermediate_states":
                        (stage_output[:batch_size]
                         if stage_output is not None else None),
                        "intermediate_inputs":
                        (intermediate_inputs[:batch_size]
                         if intermediate_inputs is not None else None),
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if self.has_seqlen_agnostic:
                        # Only used by Mamba-based models CUDA graph atm (Jamba)
                        capture_inputs["seqlen_agnostic_capture_inputs"] = (
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size))
                    graph_runner.capture(**capture_inputs)
                    # The pool from this capture is passed as "memory_pool"
                    # to the next one.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        elapsed_time = time.perf_counter() - capture_started
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
1182

1183
1184
1185
1186
    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

1187

1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        The dict is handed to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        together with this runner's attention backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        # `is_prompt` is taken from the first sequence group only; it is
        # None when the list is empty.
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward pass and, where applicable, sample.

        Returns:
            - The model's hidden/intermediate states when this rank is not
              the last pipeline-parallel rank.
            - An empty list when this worker is not the driver worker.
            - Otherwise a single-element list holding the sampler output
              (with ``hidden_states`` attached when
              ``self.return_hidden_states`` is set).

        Raises:
            ValueError: If ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            # The workspace buffers and wrappers are created lazily on the
            # first call and reused afterwards.
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                # Each captured graph runner owns its own decode wrapper,
                # looked up by virtual engine and batch size.
                batch_size = model_input.input_tokens.shape[0]
                engine_runners = self.graph_runners[
                    model_input.virtual_engine]
                model_input.attn_metadata.decode_wrapper = \
                    engine_runners[batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        # Extra request bookkeeping is only passed when
        # `self.has_seqlen_agnostic` is set.
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
            **seqlen_agnostic_kwargs)

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                # Keep only the first len(indices) rows of the graph output.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
1354
1355


1356
1357
class CUDAGraphRunner:
    """Captures one CUDA graph of ``model`` and replays it on demand.

    ``capture`` records the graph and remembers the tensors it was captured
    with as input/output buffers; ``forward`` copies fresh inputs into those
    buffers, replays the graph and returns the stored output buffer(s).
    """

    def __init__(self, model: nn.Module, backend_name: str):
        self.model = model
        self.backend_name = backend_name

        self.input_buffers: Dict[str, torch.Tensor] = {}
        # A plain dict on the last pipeline rank; `capture` stores the
        # IntermediateTensors object itself on the other ranks.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        """The captured graph; asserts that ``capture`` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, capture it into a CUDA graph, keep the buffers.

        Args:
            input_ids, positions, kv_caches, attn_metadata,
            intermediate_inputs: Passed positionally to the model during
                warm-up and capture; the tensors are kept as input buffers.
            hidden_or_intermediate_states: Optional existing output object.
                When given, the model output is copied into it inside the
                captured region; otherwise the model output itself becomes
                the output buffer.
            memory_pool: Forwarded as ``pool=`` to ``torch.cuda.graph``.
            stream: Forwarded as ``stream=`` to ``torch.cuda.graph``.
            **kwargs: Forwarded to the model and stored in the input
                buffers.

        Returns:
            The object stored as the output buffer.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        if self.backend_name == "flashinfer":
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                **kwargs,
            }
        else:
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
                **kwargs,
            }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy inputs into the captured buffers and replay the graph.

        Returns the ``"hidden_states"`` output buffer on the last
        pipeline-parallel rank and the whole stored output object on other
        ranks. The returned object is the buffer kept by ``capture``, not a
        fresh copy.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        # These two buffers are only stored by `capture` for
        # non-flashinfer backends.
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        """Alias for ``forward`` so the runner is callable like the model."""
        return self.forward(*args, **kwargs)

1511

1512
def _get_graph_batch_size(batch_size: int) -> int:
1513
1514
1515
1516
1517
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    """
1518
1519
1520
1521
1522
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
1523
1524
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)