model_runner.py 61.8 KB
Newer Older
1
import dataclasses
2
import gc
3
import time
4
import warnings
5
import weakref
6
from collections import defaultdict
7
8
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
                    Tuple, Type, TypeVar, Union)
9

10
import numpy as np
11
import torch
12
import torch.distributed
13
import torch.nn as nn
14

15
16
17
18
try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
19
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
20
21
22
23
24
25
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

26
from vllm.attention import AttentionMetadata, get_attn_backend
27
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
28
                         ModelConfig, MultiModalConfig, ParallelConfig,
29
                         PromptAdapterConfig, SchedulerConfig)
30
from vllm.distributed import get_pp_group
31
from vllm.distributed.parallel_state import graph_capture
32
from vllm.inputs import INPUT_REGISTRY
33
from vllm.logger import init_logger
34
35
36
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
37
from vllm.model_executor import SamplingMetadata
38
from vllm.model_executor.model_loader import get_model
39
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
40
41
from vllm.model_executor.models.interfaces import (supports_lora,
                                                   supports_vision)
42
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
43
44
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
                             MultiModalInputs)
45
46
47
48
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
49
from vllm.sampling_params import SamplingParams
50
51
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
52
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
53
                        is_pin_memory_available)
54
from vllm.worker.model_runner_base import (
55
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
56
57
58
59
60
61
62
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)

# Placeholder slot id (not referenced in this part of the file; presumably
# used for padded positions in the slot mapping — confirm at call sites).
_PAD_SLOT_ID = -1
# LoRA rank used for the dummy adapters registered in profile_run().
LORA_WARMUP_RANK = 8

_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for batch sizes 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
_NUM_WARMUP_ITERS = 2

TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")


@dataclasses.dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    This base class contains metadata needed for the base model forward pass
    but not metadata for possible additional steps, e.g., sampling. Model
    runners that run additional steps should subclass this class to add
    additional fields.
    """
    # NOTE: field order defines the positional constructor; keep it stable.
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    # Per-sequence lengths; these are not part of the broadcast dict.
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the fields that are broadcast to other workers.

        ``seq_lens`` and ``query_lens`` are not included. The attention
        metadata is added to the dict by
        ``_add_attn_metadata_broadcastable_dict``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "prompt_adapter_mapping": self.prompt_adapter_mapping,
            "prompt_adapter_requests": self.prompt_adapter_requests,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild a model input from a dict produced by
        ``as_broadcastable_tensor_dict``.

        If ``attn_backend`` is given, the attention metadata entries in
        ``tensor_dict`` are first converted back via
        ``_init_attn_metadata_from_tensor_dict``.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner.

    Extends ModelInputForGPU with sampling metadata, which is included in the
    broadcast payload.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the base model input's broadcast dict plus sampling metadata.

        ``is_prompt`` is intentionally not broadcast (driver-only).
        """
        # Reuse the parent's field set instead of duplicating it, so the two
        # dicts cannot drift apart.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Rebuild the model input from a broadcast dict.

        Sampling metadata entries are converted first, then (if
        ``attn_backend`` is given) the attention metadata entries.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
    """Build a ModelInputForGPU batch one sequence group at a time.

    Usage: call ``add_seq_group()`` for every SequenceGroupMetadata in the
    batch, then ``build()`` once. Per-token and per-sequence data are
    accumulated in Python lists and only converted to device tensors in
    ``build()``, which also pads decode-only batches for CUDA graph replay
    when they qualify.
    """

    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        super().__init__()
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.scheduler_config = self.runner.scheduler_config
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.enable_lora = self.runner.lora_config is not None
        self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                      is not None)
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        # Cleared by add_seq_group() as soon as a prompt group is added;
        # build() only considers CUDA graphs for decode-only batches.
        self.decode_only = True

        # Common inputs.
        self.input_tokens: List[int] = []
        self.input_positions: List[int] = []
        self.seq_lens: List[int] = []
        self.query_lens: List[int] = []
        # Largest (sliding-window capped) seq_len among decode sequences.
        self.max_decode_seq_len: int = 0
        self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)

        # LoRA inputs.
        self.lora_index_mapping: List[int] = []
        self.lora_prompt_mapping: List[int] = []
        self.lora_requests: Set[LoRARequest] = set()

        # Prompt adapter inputs.
        self.prompt_adapter_index_mapping: List[int] = []
        self.prompt_adapter_prompt_mapping: List[int] = []
        self.prompt_adapter_requests: Set[PromptAdapterRequest] = set()

        # Multi-modal inputs.
        self.multi_modal_inputs_list: List[MultiModalInputs] = []

        # Attention metadata inputs.
        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
            self)

        # Engine/Model configurations.
        self.chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Sliding window rounded up to a whole number of blocks.
            self.sliding_window_blocks = (
                self.sliding_window + self.block_size - 1) // self.block_size
            self.block_aligned_sliding_window = \
                self.sliding_window_blocks * self.block_size

    def _compute_len_for_sliding_window(self, seq_len: int):
        """Cap a decode sequence length to the sliding window.

        Returns:
            A tuple ``(curr_sliding_window_blocks, sliding_seq_len)``: the
            number of blocks attributed to the window for this sequence (0
            when there is no sliding window) and ``seq_len`` capped to it.
        """
        curr_sliding_window_blocks = 0
        sliding_seq_len = seq_len

        # TODO(sang): This is a hack to make sliding window work with
        # paged attn. We can remove it if we make paged attn kernel
        # to properly handle sliding window attn.
        if self.sliding_window is not None:
            curr_sliding_window_blocks = self.sliding_window_blocks
            if self.scheduler_config.use_v2_block_manager:
                # number of elements in last block
                suff_len = seq_len % self.block_size
                sliding_seq_len = min(
                    seq_len, self.block_aligned_sliding_window + suff_len)
                if suff_len > 0:
                    # Count the partially filled last block as well.
                    curr_sliding_window_blocks += 1
            else:
                sliding_seq_len = min(seq_len, self.sliding_window)
        return curr_sliding_window_blocks, sliding_seq_len

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Append one sequence group's inputs to the batch being built.

        A prompt group must contain exactly one sequence. For decode groups
        one token (the last token of each sequence) is added per sequence.

        Raises:
            RuntimeError: if chunked prefill is enabled and a prompt hits
                the prefix cache.
        """
        seq_ids = list(seq_group_metadata.seq_data.keys())
        n_seqs = len(seq_ids)
        is_prompt = seq_group_metadata.is_prompt
        token_chunk_size = seq_group_metadata.token_chunk_size

        if is_prompt:
            assert n_seqs == 1
            self.decode_only = False

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manage the cache by themselves.
        self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []

        # Check if hit prefix cache (i.e., some blocks are already computed).
        # Note that prefix caching does not support sliding window.
        # Every input to this check is per-group, so it is evaluated once
        # here; this also keeps it bound for the attention builder call below.
        computed_block_nums = seq_group_metadata.computed_block_nums
        prefix_cache_hit = (computed_block_nums is not None
                            and len(computed_block_nums) > 0
                            and self.sliding_window is None and is_prompt)
        if self.chunked_prefill_enabled and prefix_cache_hit:
            raise RuntimeError(
                "chunked prefill cannot be used with prefix caching now.")

        # The number of input tokens in each sequence.
        token_lens: List[int] = []
        # The number of tokens that are already computed.
        context_lens: List[int] = []
        # The current sliding window block for each sequence.
        curr_sliding_window_blocks: List[int] = []
        # The original sequence length (before applying sliding window)
        # for each sequence.
        orig_seq_lens: List[int] = []
        # The sequence length (may be capped to the sliding window).
        curr_seq_lens: List[int] = []
        for seq_id in seq_ids:
            seq_data = seq_group_metadata.seq_data[seq_id]
            self.request_ids_to_seq_ids[seq_group_metadata.request_id].append(
                seq_id)

            # Compute context length (the number of tokens that are
            # already computed) and sequence length (total number of tokens).
            seq_len = seq_data.get_len()
            if is_prompt:
                context_len = seq_data.get_num_computed_tokens()
            else:
                # get_num_computed_tokens is incorrect for spec decoding.
                # So, we should have a special logic here.
                # TODO(sang): Fix it.
                context_len = seq_len - 1
            # Limit this step to at most token_chunk_size new tokens.
            seq_len = min(seq_len, context_len + token_chunk_size)

            # Compute tokens.
            if is_prompt:
                tokens = seq_data.get_token_ids()[context_len:seq_len]
            else:
                # Optimization. get_token_ids requires the entire copy of
                # tokens.
                tokens = [seq_data.get_last_token_id()]
            if prefix_cache_hit:
                assert computed_block_nums is not None
                # Skip the tokens covered by the already computed blocks.
                context_len = len(computed_block_nums) * self.block_size
                tokens = tokens[context_len:]

            # These are seq_len/context_len capped to the sliding window.
            # They are passed to decode kernel.
            # We still need original seq_len/context_len to compute slot
            # mapping (and input position) below.
            if is_prompt:
                curr_sliding_window_block = 0
                sliding_seq_len = seq_len
                query_len = seq_len - context_len
            else:
                curr_sliding_window_block, sliding_seq_len = (
                    self._compute_len_for_sliding_window(seq_len))
                query_len = 1

            self.seq_lens.append(sliding_seq_len)
            if not is_prompt:
                self.max_decode_seq_len = max(self.max_decode_seq_len,
                                              sliding_seq_len)
            self.query_lens.append(query_len)
            self.input_tokens.extend(tokens)
            self.input_positions.extend(list(range(context_len, seq_len)))

            # Intermediate data of the current sequence group for
            # the attention metadata.
            token_lens.append(len(tokens))
            context_lens.append(context_len)
            curr_seq_lens.append(sliding_seq_len)
            curr_sliding_window_blocks.append(curr_sliding_window_block)
            orig_seq_lens.append(seq_len)

        # Update attention metadata. Note that input builder attributes
        # (self.xxx) include all added sequences, so we need to slice
        # the last n_seqs sequences.
        self.attn_metadata_builder.add_seq_group(
            seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens,
            self.query_lens[-n_seqs:], context_lens,
            curr_sliding_window_blocks, prefix_cache_hit,
            self.chunked_prefill_enabled)

        # LoRA data.
        if self.enable_lora:
            lora_id = seq_group_metadata.lora_int_id
            for query_len in self.query_lens[-n_seqs:]:
                if lora_id > 0:
                    self.lora_requests.add(seq_group_metadata.lora_request)
                self.lora_index_mapping += [lora_id] * query_len
                # One entry per sampled row: every query token when prompt
                # logprobs are requested, otherwise just one.
                self.lora_prompt_mapping.extend(
                    [lora_id] *
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     is not None else 1))

        # Prompt adapter data. Note that when is_prompt=True,
        # we expect only one sequence in the group.
        if self.enable_prompt_adapter:
            prompt_adapter_id = seq_group_metadata.prompt_adapter_id
            if prompt_adapter_id > 0 and is_prompt:
                query_len = self.query_lens[-1]
                self.prompt_adapter_requests.add(
                    seq_group_metadata.prompt_adapter_request)

                num_tokens = seq_group_metadata.\
                    prompt_adapter_num_virtual_tokens
                # The first num_tokens positions map to the adapter; the rest
                # of the query maps to 0 (no adapter).
                pm = [prompt_adapter_id
                      ] * num_tokens + [0] * (query_len - num_tokens)
                self.prompt_adapter_index_mapping += pm
                # Use the same `is not None` test as the LoRA path above so
                # that prompt_logprobs=0 (a valid request for prompt
                # logprobs) yields one entry per query token in both mappings.
                self.prompt_adapter_prompt_mapping.extend(
                    [prompt_adapter_id] *
                    (query_len if seq_group_metadata.sampling_params
                     and seq_group_metadata.sampling_params.prompt_logprobs
                     is not None else 1))

        # Multi-modal data.
        mm_data = seq_group_metadata.multi_modal_data
        if mm_data:
            mm_kwargs = self.multi_modal_input_mapper(mm_data)
            self.multi_modal_inputs_list.append(mm_kwargs)

    def build(self) -> ModelInputForGPU:
        """Convert the accumulated inputs into a model input object.

        Returns an empty ``model_input_cls()`` if no tokens were added. A
        decode-only batch that is small enough (and eager mode is not
        enforced) is padded to the size returned by
        ``_get_graph_batch_size()`` so that a captured CUDA graph can be used.
        """
        if not self.input_tokens:
            return self.model_input_cls()

        batch_size = len(self.input_tokens)
        use_captured_graph = (
            self.decode_only and not self.runner.model_config.enforce_eager
            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
            and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        # -1 means "no padding": `[x] * -1` is empty. The value is also passed
        # on to the attention metadata builder.
        cuda_graph_pad_size = -1
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            batch_size = graph_batch_size

        # Tokens and positions.
        self.input_tokens.extend([0] * cuda_graph_pad_size)
        self.input_positions.extend([0] * cuda_graph_pad_size)
        input_tokens_tensor = torch.tensor(self.input_tokens,
                                           dtype=torch.long,
                                           device=self.runner.device)
        input_positions_tensor = torch.tensor(self.input_positions,
                                              dtype=torch.long,
                                              device=self.runner.device)

        # Sequence and query lengths.
        self.seq_lens.extend([1] * cuda_graph_pad_size)

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size,
            batch_size)

        # LoRA data.
        if self.enable_lora:
            self.lora_index_mapping.extend([0] * cuda_graph_pad_size)
            lora_mapping = LoRAMapping(
                self.lora_index_mapping,
                self.lora_prompt_mapping,
            )
        else:
            lora_mapping = None

        # Prompt adapter data.
        if self.enable_prompt_adapter:
            self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
            prompt_adapter_mapping = PromptAdapterMapping(
                self.prompt_adapter_index_mapping,
                self.prompt_adapter_prompt_mapping,
            )
        else:
            prompt_adapter_mapping = None

        # Multi-modal data.
        multi_modal_kwargs = MultiModalInputs.batch(
            self.multi_modal_inputs_list, device=self.runner.device)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=self.seq_lens,
            query_lens=self.query_lens,
            lora_mapping=lora_mapping,
            lora_requests=self.lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=self.request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=self.prompt_adapter_requests)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    # Only annotated here; subclasses are expected to assign the concrete
    # model-input class. ModelInputForGPUBuilder reads it (via the runner)
    # to construct the objects it builds.
    _model_input_cls: Type[TModelInputForGPU]

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        multimodal_config: Optional[MultiModalConfig] = None,
        return_hidden_states: bool = False,
    ):
        """Store the engine configs and set up state shared by GPU runners.

        The model itself and the LoRA / prompt-adapter managers are not
        created here; ``load_model()`` sets them.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.prompt_adapter_config = prompt_adapter_config
        self.multimodal_config = multimodal_config
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture

        # One dict of captured graph runners per pipeline-parallel stage.
        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
            parallel_config)

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

        # No attention backend is created when the model reports zero
        # attention heads.
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.sliding_window,
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None

        # Multi-modal data support
        self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
            .create_input_mapper(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

        # cpu_offload_gb is given in GiB; convert to bytes.
        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

    def load_model(self) -> None:
        """Load the model weights and wrap the model with optional adapters.

        The model is loaded via ``get_model`` while ``CudaMemoryProfiler``
        records the memory it consumes (stored in ``model_memory_usage``).
        The model is then wrapped with a LoRA manager and/or a prompt-adapter
        manager when those are configured. On ROCm with an FP8 KV cache,
        KV-cache scaling factors are loaded from ``quantization_param_path``
        when one is given.

        Raises:
            RuntimeError: if FP8 KV-cache scaling factors were provided but
                the model has no callable ``load_kv_cache_scales``.
        """
        with CudaMemoryProfiler() as m:
            self.model = get_model(model_config=self.model_config,
                                   device_config=self.device_config,
                                   load_config=self.load_config,
                                   lora_config=self.lora_config,
                                   multimodal_config=self.multimodal_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config,
                                   cache_config=self.cache_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(self.model), "Model does not support LoRA"
            assert not supports_vision(
                self.model
            ), "To be tested: vision language model with LoRA settings."

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.prompt_adapter_config:
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.device,
                self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style formatting to extra
                    # args, so build the final message explicitly.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save the loaded model to ``path`` via ShardedStateLoader.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        # Imported lazily, presumably to keep the loader module off the
        # import path of this file.
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize the loaded model with TensorizerLoader.

        ``tensorizer_config`` is forwarded unchanged to
        ``TensorizerLoader.save_model``.
        """
        from vllm.model_executor.model_loader.loader import TensorizerLoader
        TensorizerLoader.save_model(
            self.model,
            tensorizer_config=tensorizer_config,
        )

    def get_max_block_per_batch(self) -> int:
        """Return how many blocks of ``block_size`` tokens cover
        ``max_seq_len_to_capture`` (ceiling division).

        Used as the width of the cached CUDA-graph block table.
        """
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForGPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        # The builder keeps a reference to the runner; passing a weak proxy
        # means the builder does not keep the runner alive (presumably to
        # avoid a reference cycle).
        builder = ModelInputForGPUBuilder(weakref.proxy(self),
                                          finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)
        return builder.build()  # type: ignore

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on worst-case dummy prompts to profile memory.

        Builds ``scheduler_config.max_num_seqs`` dummy prompt sequences whose
        lengths sum to ``scheduler_config.max_num_batched_tokens``. For vision
        models the number of sequences is reduced so that each one can carry
        the maximum number of multimodal tokens.

        When LoRA is enabled, ``lora_config.max_loras`` dummy adapters are
        registered and assigned round-robin to the sequences. The forward
        pass runs with no KV caches (a list of ``None`` per layer), and the
        method blocks on ``torch.cuda.synchronize()`` before returning.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Register `max_loras` dummy adapters so that the profile covers the
        # case where every LoRA slot is in use, and therefore the maximum
        # LoRA memory consumption. They are then assigned round-robin to the
        # dummy sequences built below.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for vision encoding, which needs
        # to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        model_config = self.model_config

        if supports_vision(self.model):
            max_mm_tokens = MULTIMODAL_REGISTRY \
                .get_max_multimodal_tokens(model_config)
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        # Total dummy token count over all sequences; used below to size the
        # intermediate tensors on non-first pipeline ranks.
        batch_size = 0
        for group_id in range(max_num_seqs):
            # Split the token budget as evenly as possible: the first
            # `max_num_batched_tokens % max_num_seqs` groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                .dummy_data_for_profiling(model_config, seq_len)

            # Having more tokens is over-conservative but otherwise fine
            assert len(seq_data.prompt_token_ids) >= seq_len, (
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but got: {len(seq_data.prompt_token_ids)}")

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        # Every dummy request is reported as finished. NOTE(review):
        # presumably so seqlen-agnostic models do not keep per-request state
        # for them; confirm.
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

771
    def remove_all_loras(self):
        """Remove every LoRA adapter through the LoRA manager.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            manager.remove_all_adapters()
            return
        raise RuntimeError("LoRA is not enabled.")
775

776
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` on the LoRA manager using ``lora_mapping``.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            manager.set_active_adapters(lora_requests, lora_mapping)
            return
        raise RuntimeError("LoRA is not enabled.")
781
782
783
784

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Register ``lora_request`` with the LoRA manager.

        Returns:
            Whatever the manager's ``add_adapter`` returns.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.add_adapter(lora_request)
        raise RuntimeError("LoRA is not enabled.")
786
787
788
789

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``.

        Returns:
            Whatever the manager's ``remove_adapter`` returns.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.remove_adapter(lora_id)
        raise RuntimeError("LoRA is not enabled.")
791
792
793
794

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id`` in the LoRA manager.

        Returns:
            Whatever the manager's ``pin_adapter`` returns.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.pin_adapter(lora_id)
        raise RuntimeError("LoRA is not enabled.")
796
797
798
799

    def list_loras(self) -> Set[int]:
        """Return the ids reported by the LoRA manager's ``list_adapters``.

        Raises:
            RuntimeError: if LoRA is not enabled (no manager).
        """
        manager = self.lora_manager
        if manager:
            return manager.list_adapters()
        raise RuntimeError("LoRA is not enabled.")

    def remove_all_prompt_adapters(self):
        """Remove every prompt adapter through the prompt-adapter manager.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            manager.remove_all_adapters()
            return
        raise RuntimeError("PromptAdapter is not enabled.")

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        """Activate ``prompt_adapter_requests`` using ``prompt_adapter_mapping``.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            manager.set_active_adapters(prompt_adapter_requests,
                                        prompt_adapter_mapping)
            return
        raise RuntimeError("PromptAdapter is not enabled.")

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Register ``prompt_adapter_request`` with the prompt-adapter manager.

        Returns:
            Whatever the manager's ``add_adapter`` returns.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            return manager.add_adapter(prompt_adapter_request)
        raise RuntimeError("PromptAdapter is not enabled.")

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove the prompt adapter with id ``prompt_adapter_id``.

        Returns:
            Whatever the manager's ``remove_adapter`` returns.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            return manager.remove_adapter(prompt_adapter_id)
        raise RuntimeError("PromptAdapter is not enabled.")

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin the prompt adapter with id ``prompt_adapter_id``.

        Returns:
            Whatever the manager's ``pin_adapter`` returns.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            return manager.pin_adapter(prompt_adapter_id)
        raise RuntimeError("PromptAdapter is not enabled.")

    def list_prompt_adapters(self) -> Set[int]:
        """Return the ids reported by the prompt-adapter manager.

        Raises:
            RuntimeError: if prompt adapters are not enabled (no manager).
        """
        manager = self.prompt_adapter_manager
        if manager:
            return manager.list_adapters()
        raise RuntimeError("PromptAdapter is not enabled.")
835

836
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph is captured per (virtual engine, batch size) pair, for
        every size in ``_BATCH_SIZES_TO_CAPTURE`` that does not exceed
        ``_get_graph_batch_size(scheduler_config.max_num_seqs)``. Each graph
        is stored in ``self.graph_runners[virtual_engine][batch_size]``, and
        all graphs share ``self.graph_memory_pool``.

        Args:
            kv_caches: KV cache tensors, indexed by virtual engine.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Per-virtual-engine output buffers passed to each capture.
        # NOTE(review): nothing below stores the value returned by
        # `graph_runner.capture`, so every entry stays None and each captured
        # graph allocates its own output buffer. Confirm whether reusing the
        # first capture's output was intended.
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

        graph_batch_size = _get_graph_batch_size(
            self.scheduler_config.max_num_seqs)
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        backend_name = self.attn_backend.get_name()
        use_flashinfer = backend_name == "flashinfer"
        if use_flashinfer:
            # For flashinfer, different batch sizes will share the
            # same workspace buffer.
            decode_workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.device)
            indices_buffer = torch.empty(max_batch_size *
                                         self.cache_config.num_gpu_blocks,
                                         dtype=torch.int32,
                                         device=self.device)
            indptr_buffer = torch.empty(max_batch_size + 1,
                                        dtype=torch.int32,
                                        device=self.device)
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)
            # These depend only on the model/parallel config, not on the
            # batch size, so compute them once rather than per graph.
            num_qo_heads = self.model_config.get_num_attention_heads(
                self.parallel_config)
            num_kv_heads = self.model_config.get_num_kv_heads(
                self.parallel_config)
            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
            kv_cache_dtype = get_kv_cache_torch_dtype(
                self.kv_cache_dtype, self.model_config.dtype)

        with graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    if use_flashinfer:
                        # Slice into fresh names. Reassigning the shared
                        # buffers would shrink them permanently, so every
                        # virtual engine after the first would be captured
                        # with buffers sized for the smallest batch.
                        _indptr_buffer = indptr_buffer[:batch_size + 1]
                        _last_page_len_buffer = (
                            last_page_len_buffer[:batch_size])

                        decode_wrapper = \
                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                                decode_workspace_buffer, _indptr_buffer,
                                indices_buffer, _last_page_len_buffer, "NHD",
                                use_tensor_cores)

                        paged_kv_indptr_tensor_host = torch.arange(
                            0, batch_size + 1, dtype=torch.int32)
                        paged_kv_indices_tensor_host = torch.arange(
                            0, batch_size, dtype=torch.int32)
                        paged_kv_last_page_len_tensor_host = torch.full(
                            (batch_size, ), self.block_size, dtype=torch.int32)
                        query_start_loc_host = torch.arange(0,
                                                            batch_size + 1,
                                                            dtype=torch.int32)

                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            slot_mapping=slot_mapping[:batch_size],
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            max_prefill_seq_len=0,
                            block_tables=block_tables,
                            paged_kv_indptr=paged_kv_indptr_tensor_host,
                            paged_kv_indices=paged_kv_indices_tensor_host,
                            paged_kv_last_page_len=
                            paged_kv_last_page_len_tensor_host,
                            num_qo_heads=num_qo_heads,
                            num_kv_heads=num_kv_heads,
                            head_dim=self.model_config.get_head_size(),
                            page_size=self.block_size,
                            seq_start_loc=None,
                            query_start_loc=query_start_loc_host,
                            device=self.device,
                            data_type=kv_cache_dtype,
                            use_cuda_graph=True,
                            decode_wrapper=decode_wrapper,
                            prefill_wrapper=None)
                        attn_metadata.begin_forward()
                    else:
                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            slot_mapping=slot_mapping[:batch_size],
                            seq_lens=None,
                            seq_lens_tensor=seq_lens[:batch_size],
                            max_query_len=None,
                            max_prefill_seq_len=0,
                            max_decode_seq_len=self.max_seq_len_to_capture,
                            query_start_loc=None,
                            seq_start_loc=None,
                            context_lens_tensor=None,
                            block_tables=block_tables[:batch_size],
                            use_cuda_graph=True,
                        )

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            [0] * batch_size,
                            [0] * batch_size,
                        )
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)

                    graph_runner = CUDAGraphRunner(self.model, backend_name)

                    if use_flashinfer:
                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
                        graph_runner.flashinfer_indices_buffer = indices_buffer
                        graph_runner.flashinfer_last_page_len_buffer = \
                            _last_page_len_buffer
                        graph_runner.flashinfer_decode_workspace_buffer = \
                            decode_workspace_buffer
                        graph_runner.flashinfer_decode_wrapper = \
                            decode_wrapper

                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[:batch_size],
                        "hidden_or_intermediate_states":
                        hidden_or_intermediate_states[
                            virtual_engine]  # type: ignore
                        [:batch_size]
                        if hidden_or_intermediate_states[virtual_engine]
                        is not None else None,
                        "intermediate_inputs":
                        intermediate_inputs[:batch_size]
                        if intermediate_inputs is not None else None,
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if self.has_seqlen_agnostic:
                        # Only used by Mamba-based models CUDA graph atm (Jamba)
                        capture_inputs.update({
                            "seqlen_agnostic_capture_inputs":
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size)
                        })
                    graph_runner.capture(**capture_inputs)
                    # Later captures allocate from the same memory pool.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
1053

1054
1055
1056
1057
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        return config.get_vocab_size()

1058

1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict.

        ``self.attn_backend`` is passed through so the attention metadata
        can be reconstructed for this runner's backend.
        """
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Args:
            seq_group_metadata_list: Sequence groups to batch.
            virtual_engine: Stored on the result; ``execute_model`` uses it
                to pick the CUDA graph runners for this engine.
            finished_requests_ids: Forwarded to the input builder.

        Returns:
            The base model input with ``sampling_metadata``, ``is_prompt``
            (``None`` for an empty list) and ``virtual_engine`` filled in.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                     model_input.seq_lens,
                                                     model_input.query_lens,
                                                     self.device,
                                                     self.pin_memory)
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward step and, on the last pipeline rank, sample.

        Args:
            model_input: Inputs produced by :meth:`prepare_model_input`.
            kv_caches: KV cache tensors handed to the model.
            intermediate_tensors: Inputs from the previous pipeline rank;
                ``None`` on the first rank.
            num_steps: Must be 1; multi-step execution is not supported.

        Returns:
            The hidden/intermediate states on non-last pipeline ranks, an
            empty list on non-driver workers, and ``[SamplerOutput]`` on the
            driver worker of the last rank.

        Raises:
            ValueError: if ``num_steps > 1``.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                # Lazily allocate the eager-mode flashinfer workspaces and
                # wrappers on first use.
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                    self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                    self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                # Use the decode wrapper bound to the captured graph's
                # buffers for this batch size.
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        # decode_meta may be None as well (no decode tokens); guard it
        # instead of raising AttributeError.
        if (prefill_meta is None and decode_meta is not None
                and decode_meta.use_cuda_graph):
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}
        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **multi_modal_kwargs,
            **seqlen_agnostic_kwargs)

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta is not None and decode_meta.use_cuda_graph:
                # CUDA-graph inputs are padded; keep only the real rows.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
1225
1226


1227
1228
class CUDAGraphRunner:
    """Capture one model forward pass as a CUDA graph and replay it.

    ``capture`` runs the model a few times eagerly, records one more call
    into a ``torch.cuda.CUDAGraph`` and remembers the tensors that call used
    as ``input_buffers`` / ``output_buffers``. ``forward`` copies fresh
    inputs into those same buffers, replays the graph and returns the
    output buffers.
    """

    def __init__(self, model: nn.Module, backend_name: str):
        self.model = model
        # Selects which attention-metadata tensors are recorded as graph
        # inputs (the "flashinfer" backend skips seq_lens/block_tables).
        self.backend_name = backend_name

        self.input_buffers: Dict[str, torch.Tensor] = {}
        # On the last pipeline-parallel rank this becomes
        # {"hidden_states": tensor}; on other ranks capture() stores the
        # IntermediateTensors object itself.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

        # FlashInfer state. None of the methods in this class read these;
        # NOTE(review): presumably populated by the owning model runner when
        # capturing with the flashinfer backend — confirm.
        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        """The captured CUDA graph; only valid after ``capture`` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm up the model, capture one forward call and record buffers.

        Args:
            input_ids: Token ids; recorded as the ``"input_ids"`` buffer.
            positions: Positions; recorded as the ``"positions"`` buffer.
            hidden_or_intermediate_states: Optional preallocated output.
                When given, the captured graph also copies the model output
                into it: the whole tensor on the last pipeline rank, or each
                key of the IntermediateTensors on other ranks. When ``None``,
                the model's own output is used as the output buffer.
            intermediate_inputs: Optional inputs from a previous pipeline
                stage. Their tensors are added to ``input_buffers`` by key.
            kv_caches: Passed to the model and stored as-is.
            attn_metadata: Passed to the model. Its ``slot_mapping`` is
                recorded, plus ``decode_metadata.seq_lens_tensor`` and
                ``decode_metadata.block_tables`` for non-flashinfer backends.
            memory_pool: Forwarded as ``pool=`` to ``torch.cuda.graph``.
            stream: Stream the capture runs on.
            **kwargs: Forwarded to the model and also stored in
                ``input_buffers``.

        Returns:
            The output buffer that replays will write into.

        Raises:
            AssertionError: If a graph has already been captured.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
        torch.cuda.synchronize()

        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids,
                positions,
                kv_caches,
                attn_metadata,
                intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                # Caller supplied the output buffer: record a copy of the
                # model output into it as part of the graph.
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            "slot_mapping": attn_metadata.slot_mapping,
        }
        if self.backend_name != "flashinfer":
            # Non-flashinfer backends also record the decode metadata
            # tensors so forward() can refresh them before each replay.
            decode_metadata = attn_metadata.decode_metadata
            self.input_buffers["seq_lens_tensor"] = (
                decode_metadata.seq_lens_tensor)
            self.input_buffers["block_tables"] = decode_metadata.block_tables
        # Extra model kwargs come last so they override, as `**kwargs` did.
        self.input_buffers.update(kwargs)

        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy new inputs into the captured buffers and replay the graph.

        Returns:
            On the last pipeline rank, the ``"hidden_states"`` output buffer.
            On other ranks, the IntermediateTensors output buffer. The same
            objects are returned on every call, and the next replay
            overwrites them.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers. The graph replays
        # against the tensors recorded at capture time, so new values must
        # be written into those same tensors.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        # capture() records these two buffers only for non-flashinfer
        # backends.
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        # This key exists only if it was among the **kwargs at capture time.
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                self.input_buffers[key].copy_(intermediate_tensors[key],
                                              non_blocking=True)
        # Run the graph.
        self.graph.replay()
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
                                                      **kwargs)
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1382

1383
def _get_graph_batch_size(batch_size: int) -> int:
    """Return the padded batch size for an actual ``batch_size``.

    Sizes of 2 or less are returned unchanged, 3 and 4 are padded to 4, and
    anything larger is rounded up to a multiple of ``_BATCH_SIZE_ALIGNMENT``.
    Assuming ``_BATCH_SIZE_ALIGNMENT`` is larger than 4, the padded sizes are
    1, 2, 4, _BATCH_SIZE_ALIGNMENT, 2*_BATCH_SIZE_ALIGNMENT, ...
    """
    if batch_size > 4:
        # Integer ceil(batch_size / alignment), then scale back up.
        num_aligned_chunks = ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                              _BATCH_SIZE_ALIGNMENT)
        return num_aligned_chunks * _BATCH_SIZE_ALIGNMENT
    return 4 if batch_size > 2 else batch_size
1396
1397


1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
def _is_block_tables_empty(block_tables: Union[None, Dict]):
    """
    Check if block_tables is None or a dictionary with all None values.
    """
    if block_tables is None:
        return True
    if isinstance(block_tables, dict) and all(
            value is None for value in block_tables.values()):
        return True
    return False