model_runner.py 80.5 KB
Newer Older
1
import dataclasses
2
import gc
3
import inspect
4
import itertools
5
import time
6
import warnings
7
import weakref
8
from dataclasses import dataclass
9
10
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                    Tuple, Type, TypeVar, Union)
11

12
import numpy as np
13
import torch
14
import torch.distributed
15
import torch.nn as nn
16
import torch.nn.functional as F
17

18
import vllm.envs as envs
19
from vllm.attention import AttentionMetadata, get_attn_backend
20
21
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
22
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
23
24
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
25
from vllm.core.scheduler import SchedulerOutputs
26
from vllm.distributed import get_pp_group
27
from vllm.distributed.parallel_state import graph_capture
28
from vllm.inputs import INPUT_REGISTRY, InputRegistry
29
from vllm.logger import init_logger
30
31
32
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
33
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
34
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
35
from vllm.model_executor.layers.sampler import SamplerOutput
36
from vllm.model_executor.model_loader import get_model
37
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
38
from vllm.model_executor.models.interfaces import (supports_lora,
39
                                                   supports_multimodal)
40
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
41
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
42
                             MultiModalInputs, MultiModalRegistry)
43
44
45
46
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
47
from vllm.sampling_params import SamplingParams
48
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
49
from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
50
51
                        flatten_2d_lists, is_hip, is_pin_memory_available,
                        supports_dynamo)
52
from vllm.worker.model_runner_base import (
53
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
54
55
56
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
57
    _init_sampling_metadata_from_tensor_dict, dump_input_when_exception)
58
59
60

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
61
62
63

logger = init_logger(__name__)

64
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# All the token sizes that **can** be captured by cudagraph.
# They can be arbitrarily large.
# Currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# The actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
_NUM_WARMUP_ITERS = 2
76

77
78
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")

79
80
81
82
# For now, bump up cache limits for recompilations during CUDA graph warmups.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128

83

84
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """
    This base class contains metadata needed for the base model forward pass
    but not metadata for possible additional steps, e.g., sampling. Model
    runners that run additional steps should subclass this class to add
    additional fields.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the broadcastable fields of this input as a dict.

        Only the fields listed below are included; ``seq_lens``,
        ``query_lens``, ``async_callback``, ``seq_group_metadata_list`` and
        ``scheduler_outputs`` are not part of the dict. The attention
        metadata is merged in by ``_add_attn_metadata_broadcastable_dict``.
        """
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "lora_requests": self.lora_requests,
            "lora_mapping": self.lora_mapping,
            "multi_modal_kwargs": self.multi_modal_kwargs,
            "prompt_adapter_mapping": self.prompt_adapter_mapping,
            "prompt_adapter_requests": self.prompt_adapter_requests,
            "virtual_engine": self.virtual_engine,
            "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
            "finished_requests_ids": self.finished_requests_ids,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Build an instance of ``cls`` from a broadcasted tensor dict.

        When ``attn_backend`` is given, the dict is first passed through
        ``_init_attn_metadata_from_tensor_dict``; the resulting dict is then
        used as keyword arguments for the constructor.
        """
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


137
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Return the base class's broadcastable dict plus sampling metadata.

        The base fields and attention metadata come from the parent
        implementation so the two classes cannot drift apart; the sampling
        metadata is then merged in by
        ``_add_sampling_metadata_broadcastable_dict``.
        """
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Build an instance from a broadcasted tensor dict.

        The dict is first passed through
        ``_init_sampling_metadata_from_tensor_dict``; attention-metadata
        handling and construction are delegated to the parent classmethod.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        return super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)


178
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
179
180
    """Build ModelInputForGPU from SequenceGroupMetadata."""

181
182
183
    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
184
185
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group."""

        def simple_reinit(self):
            """Reset per-sequence state in place for a one-sequence group.

            Only index 0 of the per-sequence lists is reset, and the existing
            containers are cleared instead of being replaced.
            """
            self.input_tokens[0].clear()  # type: ignore
            self.input_positions[0].clear()  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            self.seq_lens[0] = 0  # type: ignore
            self.orig_seq_lens[0] = 0  # type: ignore
            self.query_lens[0] = 0  # type: ignore
            self.context_lens[0] = 0  # type: ignore
            self.curr_sliding_window_blocks[0] = 0  # type: ignore
            self.lora_index_mapping.clear()  # type: ignore
            self.lora_prompt_mapping.clear()  # type: ignore
            self.lora_requests.clear()  # type: ignore
            self.prompt_adapter_index_mapping.clear()  # type: ignore
            self.prompt_adapter_prompt_mapping.clear()  # type: ignore

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            input_positions: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_inputs: Optional[MultiModalInputs] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
            reinit: bool = False,
            reinit_use_defaults: bool = False,
            encoder_seq_len: int = 0,
        ):
            """Initialize, or re-initialize in place, the intermediate data.

            With ``reinit=False`` the object is set up from scratch and
            ``__post_init__`` runs last. Note that ``__post_init__`` replaces
            the token, position, length and LoRA mapping lists with empty
            defaults sized from ``seq_ids``, so values passed for those
            arguments do not survive a fresh construction.

            With ``reinit=True`` the object must already be initialized with
            the same number of sequence ids. Existing containers are updated
            in place: a truthy argument replaces the stored container, a
            falsy one clears / zeroes the stored container. If there is a
            single sequence and ``reinit_use_defaults`` is set,
            ``simple_reinit`` is used instead and the per-field arguments
            are ignored. ``mrope_input_positions`` is always reset to
            ``None`` on re-initialization.

            ``n_seqs`` is accepted but the stored value is always recomputed
            as ``len(self.seq_ids)``.
            """
            if reinit:
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                # Copy element-wise so the existing list object is kept.
                for i, seq_id in enumerate(seq_ids):
                    self.seq_ids[i] = seq_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.encoder_seq_len = encoder_seq_len

            if reinit:
                if len(self.seq_ids) == 1 and reinit_use_defaults:
                    self.simple_reinit()
                else:
                    if input_tokens:
                        self.input_tokens = input_tokens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_tokens[seq_id].clear()

                    if input_positions:
                        self.input_positions = input_positions
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.input_positions[seq_id].clear()

                    self.mrope_input_positions = None

                    if seq_lens:
                        self.seq_lens = seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.seq_lens[seq_id] = 0

                    if orig_seq_lens:
                        self.orig_seq_lens = orig_seq_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.orig_seq_lens[seq_id] = 0

                    if query_lens:
                        self.query_lens = query_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.query_lens[seq_id] = 0

                    if context_lens:
                        self.context_lens = context_lens
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.context_lens[seq_id] = 0

                    if curr_sliding_window_blocks:
                        self.curr_sliding_window_blocks = \
                            curr_sliding_window_blocks
                    else:
                        for seq_id in range(len(self.seq_ids)):
                            self.curr_sliding_window_blocks[seq_id] = 0

                    if lora_index_mapping:
                        self.lora_index_mapping = lora_index_mapping
                    else:
                        self.lora_index_mapping.clear()

                    if lora_prompt_mapping:
                        self.lora_prompt_mapping = lora_prompt_mapping
                    else:
                        self.lora_prompt_mapping.clear()

                    if lora_requests:
                        self.lora_requests = lora_requests
                    else:
                        self.lora_requests.clear()

                    if prompt_adapter_index_mapping:
                        self.prompt_adapter_index_mapping = \
                            prompt_adapter_index_mapping
                    else:
                        self.prompt_adapter_index_mapping.clear()

                    if prompt_adapter_prompt_mapping:
                        self.prompt_adapter_prompt_mapping = \
                            prompt_adapter_prompt_mapping
                    else:
                        self.prompt_adapter_prompt_mapping.clear()
            else:
                self.input_tokens = input_tokens or []
                self.input_positions = input_positions or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = \
                    curr_sliding_window_blocks or []

                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()

                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_inputs = multi_modal_inputs
            self.prefix_cache_hit = prefix_cache_hit

            # The stored count always follows the actual sequence ids.
            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate empty per-sequence containers sized from ``seq_ids``.

            Called only on fresh construction (``reinit=False``); it
            overwrites the token, position, length and LoRA mapping fields.
            """
            self.n_seqs = len(self.seq_ids)

            self.input_tokens = [[] for _ in range(self.n_seqs)]
            self.input_positions = [[] for _ in range(self.n_seqs)]
            self.mrope_input_positions = None
            self.seq_lens = [0] * self.n_seqs
            self.orig_seq_lens = [0] * self.n_seqs
            self.query_lens = [0] * self.n_seqs
            self.context_lens = [0] * self.n_seqs
            self.curr_sliding_window_blocks = [0] * self.n_seqs

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []

    def gen_inter_data_builder(self, num_seqs: int):
        """Return a zero-argument factory for placeholder inter-data objects.

        Every call of the returned factory creates a new
        ``InterDataForSeqGroup`` with ``num_seqs`` placeholder sequence ids
        (all 0), an empty request id, ``is_prompt=True``, no block tables and
        an empty list of computed block numbers.
        """

        def _make_placeholder():
            return ModelInputForGPUBuilder.InterDataForSeqGroup(
                request_id="",
                seq_ids=[0] * num_seqs,
                is_prompt=True,
                block_tables=None,
                computed_block_nums=[])

        return _make_placeholder

    def init_cached_inter_data(self, *args, **kwargs):
        """Fetch a cached inter-data object and (re)initialize it.

        Only keyword arguments are accepted and ``seq_ids`` is required.
        The runner keeps one ``PyObjectCache`` per number of sequences; the
        cache for this size is created on first use. The object obtained
        from it is re-run through ``__init__`` with the given arguments and
        returned.
        """
        assert len(args) == 0
        assert "seq_ids" in kwargs
        group_size = len(kwargs["seq_ids"])

        # The inter-data cache is per model_runner
        caches_by_size = self.runner.inter_data_cache
        if group_size not in caches_by_size:
            factory = self.gen_inter_data_builder(group_size)
            caches_by_size[group_size] = PyObjectCache(factory)

        inter_data = caches_by_size[group_size].get_object()
        inter_data.__init__(*args, **kwargs)
        return inter_data

    def reset_cached_inter_data(self):
        """Call ``reset()`` on every per-size cache held by the runner."""
        caches_by_size = self.runner.inter_data_cache
        for object_cache in caches_by_size.values():
            object_cache.reset()
409
410
411
412
413

    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        """Set up a builder bound to ``runner``.

        Args:
            runner: The model runner whose configuration (attention backend,
                scheduler config, sliding window, block size, LoRA / prompt
                adapter config, multi-modal mapper, model config) is copied
                onto this builder.
            finished_requests_ids: Stored as-is on the builder.
        """
        super().__init__()
        # Compute functions for each sequence in a sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_compute_fns = [
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Compute functions for each sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.scheduler_config = self.runner.scheduler_config
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.enable_lora = self.runner.lora_config is not None
        self.enable_prompt_adapter = (self.runner.prompt_adapter_config
                                      is not None)
        self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
        self.finished_requests_ids = finished_requests_ids
        # Starts as True; add_seq_group flips it when a prompt is added.
        self.decode_only = True

        # Intermediate data (data in CPU before going to GPU) for
        # the current sequence group.
        self.inter_data_list: List[
            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

        # Attention metadata inputs.
        # The metadata builder receives a weak proxy to this builder rather
        # than a strong reference.
        self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
            weakref.proxy(self))

        # Engine/Model configurations.
        self.chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Round the window up to a whole number of blocks.
            self.sliding_window_blocks = (
                self.sliding_window + self.block_size - 1) // self.block_size
            self.block_aligned_sliding_window = \
                self.sliding_window_blocks * self.block_size

        self.is_encoder_decoder_model = (
            self.runner.model_config.is_encoder_decoder_model)
462

463
464
465
466
467
468
469
    def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                      seq_group_metadata: SequenceGroupMetadata):
        """Compute context length, sequence length and tokens
        for the given sequence data.

        Writes ``seq_lens``, ``orig_seq_lens``, ``context_lens``,
        ``query_lens``, ``input_tokens`` and ``input_positions`` for
        ``seq_idx`` on ``inter_data``; also fills ``mrope_input_positions``
        when the sequence carries an ``mrope_position_delta``.
        """
        seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
        token_chunk_size = seq_group_metadata.token_chunk_size

        # Compute context length (the number of tokens that are
        # already computed) and sequence length (total number of tokens).
        seq_len = seq_data.get_len()
        if inter_data.is_prompt:
            context_len = seq_data.get_num_computed_tokens()
        else:
            # get_num_computed_tokens is incorrect for spec decoding.
            # So, we should have a special logic here.
            # TODO(sang): Fix it.
            context_len = seq_len - 1
        seq_len = min(seq_len, context_len + token_chunk_size)

        # Compute tokens.
        if inter_data.is_prompt:
            tokens = seq_data.get_token_ids()
            # Slice only when the chunk is a strict sub-range of the tokens.
            if context_len != 0 or seq_len < len(tokens):
                tokens = tokens[context_len:seq_len]
        else:
            # Optimization. get_token_ids requires the entire copy of
            # tokens.
            tokens = seq_data.get_last_token_id()

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.context_lens[seq_idx] = context_len

        if isinstance(tokens, list):
            inter_data.input_tokens[seq_idx].extend(tokens)
        else:
            inter_data.input_tokens[seq_idx].append(tokens)

        # Note: this replaces the stored positions list with a new one.
        inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))

        inter_data.query_lens[
            seq_idx] = seq_len - context_len if inter_data.is_prompt else 1

        if seq_data.mrope_position_delta is not None:
            if inter_data.mrope_input_positions is None:
                inter_data.mrope_input_positions = [None] * inter_data.n_seqs

            inter_data.mrope_input_positions[
                seq_idx] = MRotaryEmbedding.get_next_input_positions(
                    seq_data.mrope_position_delta,
                    context_len,
                    seq_len,
                )

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
    def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int,
            seq_group_metadata: SequenceGroupMetadata):
        """Check if hit prefix cache (i.e., some blocks are already computed).
        If hit, update input tokens and positions to only compute the
        remaining blocks.

        Always records the result in ``inter_data.prefix_cache_hit``. On a
        hit, ``input_tokens``, ``input_positions``, ``context_lens`` and
        ``query_lens`` for ``seq_idx`` may be trimmed / updated.
        """
        computed_block_nums = inter_data.computed_block_nums

        # Note that prefix caching does not support sliding window.
        prefix_cache_hit = (computed_block_nums is not None
                            and len(computed_block_nums) > 0
                            and self.sliding_window is None
                            and inter_data.is_prompt)
        inter_data.prefix_cache_hit = prefix_cache_hit

        if not prefix_cache_hit:
            return

        assert computed_block_nums is not None
        # The cache hit prompt tokens in this sequence. Note that
        # this may be larger than the sequence length if chunked
        # prefill is enabled.
        prefix_cache_len = len(computed_block_nums) * self.block_size
        # The number of so far computed prompt tokens in this sequence.
        context_len = inter_data.context_lens[seq_idx]
        # The total number of prompt tokens in this sequence.
        # When chunked prefill is enabled, this is the token number of
        # computed chunks + current chunk.
        seq_len = inter_data.seq_lens[seq_idx]
        if prefix_cache_len <= context_len:
            # We already passed the cache hit region,
            # so do normal computation.
            pass
        elif context_len < prefix_cache_len < seq_len:
            # Partial hit. Compute the missing part.
            uncomputed_start = prefix_cache_len - context_len
            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
                seq_idx][uncomputed_start:]
            inter_data.input_positions[seq_idx] = inter_data.input_positions[
                seq_idx][uncomputed_start:]
            context_len = prefix_cache_len

            inter_data.context_lens[seq_idx] = context_len
            inter_data.query_lens[
                seq_idx] = inter_data.seq_lens[seq_idx] - context_len
        elif seq_len <= prefix_cache_len:
            # Full hit. Only compute the last token to avoid
            # erroneous behavior. FIXME: Ideally we should directly
            # mark all tokens as computed in the scheduler and do not
            # schedule this sequence, so this case should not happen.
            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
                seq_idx][-1:]
            inter_data.input_positions[seq_idx] = inter_data.input_positions[
                seq_idx][-1:]
            inter_data.query_lens[seq_idx] = 1
            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
575
576
577
578
579
580
581
582
583
584
585
586
587
588

    def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                    seq_idx: int,
                                    seq_group_metadata: SequenceGroupMetadata):
        """Update seq_len and curr_sliding_window_block for the given
        sequence data (only required by decoding) if sliding window is enabled.

        For prompts, or when no sliding window is configured, the stored
        sequence length is left as is and the current sliding window block
        is set to 0.
        """
        curr_sliding_window_block = 0
        sliding_seq_len = inter_data.seq_lens[seq_idx]
        if not inter_data.is_prompt and self.sliding_window is not None:
            # TODO(sang): This is a hack to make sliding window work with
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle sliding window attn.
            curr_sliding_window_block = self.sliding_window_blocks
            if self.scheduler_config.use_v2_block_manager:
                # number of elements in last block
                suff_len = inter_data.seq_lens[seq_idx] % self.block_size
                sliding_seq_len = min(
                    inter_data.seq_lens[seq_idx],
                    self.block_aligned_sliding_window + suff_len)
                if suff_len > 0:
                    curr_sliding_window_block += 1
            else:
                sliding_seq_len = min(inter_data.seq_lens[seq_idx],
                                      self.sliding_window)

        inter_data.curr_sliding_window_blocks[
            seq_idx] = curr_sliding_window_block
        inter_data.seq_lens[seq_idx] = sliding_seq_len

    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                            seq_idx: int,
                            seq_group_metadata: SequenceGroupMetadata):
        """If LoRA is enabled, compute LoRA index and prompt mapping.

        Appends one entry per call to ``lora_index_mapping`` (one LoRA id
        per query token) and to ``lora_prompt_mapping`` (one id per query
        token when prompt logprobs are requested, otherwise a single id).
        """
        if not self.enable_lora:
            return

        adapter_id = seq_group_metadata.lora_int_id
        if adapter_id > 0:
            inter_data.lora_requests.add(seq_group_metadata.lora_request)

        n_query_tokens = inter_data.query_lens[seq_idx]
        inter_data.lora_index_mapping.append([adapter_id] * n_query_tokens)

        if (seq_group_metadata.sampling_params
                and seq_group_metadata.sampling_params.prompt_logprobs
                is not None):
            n_prompt_entries = n_query_tokens
        else:
            n_prompt_entries = 1
        inter_data.lora_prompt_mapping.append([adapter_id] * n_prompt_entries)

    def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup,
            seq_group_metadata: SequenceGroupMetadata):
        """If prompt adapter is enabled, compute index and prompt mapping.

        Does nothing unless prompt adapters are enabled, the group has a
        positive ``prompt_adapter_id`` and the group is a prompt. Otherwise
        it stores the prompt adapter request and overwrites
        ``prompt_adapter_index_mapping`` / ``prompt_adapter_prompt_mapping``
        on ``inter_data``.
        """
        if not self.enable_prompt_adapter:
            return

        prompt_adapter_id = seq_group_metadata.prompt_adapter_id
        if prompt_adapter_id <= 0 or not inter_data.is_prompt:
            return

        # We expect only one sequence in the group when is_prompt=True.
        assert inter_data.n_seqs == 1
        query_len = inter_data.query_lens[0]
        inter_data.prompt_adapter_request = (
            seq_group_metadata.prompt_adapter_request)

        # The first num_tokens positions map to the adapter, the rest to 0.
        num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens
        inter_data.prompt_adapter_index_mapping = [
            prompt_adapter_id
        ] * num_tokens + [0] * (query_len - num_tokens)

        # Use the same "prompt logprobs requested" test as
        # _compute_lora_input: `is not None`, so that prompt_logprobs=0 is
        # treated as a request rather than as "not set".
        sampling_params = seq_group_metadata.sampling_params
        wants_prompt_logprobs = bool(
            sampling_params and sampling_params.prompt_logprobs is not None)
        inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * (
            query_len if wants_prompt_logprobs else 1)

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """If multi-modal data is given, add it to the input.

        The data is passed through ``self.multi_modal_input_mapper`` and the
        result stored in ``inter_data.multi_modal_inputs``. When the runner
        reports ``model_is_mrope``, the mrope input positions are recomputed
        for every sequence in the group and the resulting position delta is
        stored on the sequence data.
        """
        mm_data = seq_group_metadata.multi_modal_data
        if not mm_data:
            return

        mm_kwargs = self.multi_modal_input_mapper(mm_data)
        inter_data.multi_modal_inputs = mm_kwargs

        # special processing for mrope position deltas.
        if self.runner.model_is_mrope:
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config

            inter_data.mrope_input_positions = [None] * inter_data.n_seqs
            for seq_idx in range(inter_data.n_seqs):
                seq_data = seq_group_metadata.seq_data[
                    inter_data.seq_ids[seq_idx]]
                token_ids = seq_data.get_token_ids()

                mrope_input_positions, mrope_position_delta = \
                    MRotaryEmbedding.get_input_positions(
                        token_ids,
                        image_grid_thw=image_grid_thw,
                        video_grid_thw=video_grid_thw,
                        image_token_id=hf_config.image_token_id,
                        video_token_id=hf_config.video_token_id,
                        vision_start_token_id=hf_config.vision_start_token_id,
                        vision_end_token_id=hf_config.vision_end_token_id,
                        spatial_merge_size=hf_config.vision_config.
                        spatial_merge_size,
                        context_len=inter_data.context_lens[seq_idx],
                    )

                seq_data.mrope_position_delta = mrope_position_delta
                inter_data.mrope_input_positions[
                    seq_idx] = mrope_input_positions

695
    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Add a sequence group to the builder.

        Creates (or reuses) the cached intermediate data for the group,
        appends it to ``self.inter_data_list`` and then runs every registered
        per-sequence and per-sequence-group compute function on it.

        Args:
            seq_group_metadata: Metadata of the sequence group to add. A
                prompt group must contain exactly one sequence.
        """
        seq_ids = seq_group_metadata.seq_data.keys()
        n_seqs = len(seq_ids)
        is_prompt = seq_group_metadata.is_prompt

        if is_prompt:
            assert n_seqs == 1
            # A single prompt group makes the whole batch non-decode-only.
            self.decode_only = False

        encoder_seq_len = 0

        if self.is_encoder_decoder_model:
            encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()

        inter_data = self.init_cached_inter_data(
            request_id=seq_group_metadata.request_id,
            seq_ids=seq_ids,
            is_prompt=is_prompt,
            block_tables=seq_group_metadata.block_tables,
            computed_block_nums=seq_group_metadata.computed_block_nums,
            reinit=True,
            reinit_use_defaults=True,
            encoder_seq_len=encoder_seq_len)

        self.inter_data_list.append(inter_data)

        for seq_idx in range(n_seqs):
            for per_seq_fn in self.per_seq_compute_fns:
                per_seq_fn(inter_data, seq_idx, seq_group_metadata)
        for per_seq_group_fn in self.per_seq_group_compute_fns:
            per_seq_group_fn(inter_data, seq_group_metadata)
727

728
729
730
731
    def _use_captured_graph(self,
                            batch_size: int,
                            max_decode_seq_len: int,
                            max_encoder_seq_len: int = 0) -> bool:
        """Return whether this batch can run with a captured graph.

        All of the following must hold: the batch is decode-only, eager mode
        is not enforced, the batch size is within both the largest captured
        batch size and the runner's capture limit, and the decoder/encoder
        sequence lengths do not exceed ``max_seq_len_to_capture``.

        Args:
            batch_size: Number of tokens in the (unpadded) batch.
            max_decode_seq_len: Longest decode sequence length in the batch.
            max_encoder_seq_len: Longest encoder sequence length in the batch.
        """
        return (self.decode_only and not self.runner.model_config.enforce_eager
                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                and batch_size <= self.runner.max_batchsize_to_capture)
737

738
    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and
        create on-device tensors.

        Returns:
            An instance of ``self.model_input_cls``. It is constructed with
            no arguments when no input tokens were collected; otherwise it
            carries the token/position tensors, attention metadata, and the
            LoRA, prompt-adapter and multi-modal data gathered from
            ``self.inter_data_list``.
        """
        # Combine and flatten intermediate data.
        input_tokens = []
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)

        if not input_tokens:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            # Three position rows are built. Groups without mrope positions
            # contribute their regular positions to each of the three rows.
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        query_lens = []
        max_decode_seq_len = 0
        max_encoder_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            query_lens.extend(inter_data.query_lens)
            if not inter_data.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))
                if self.is_encoder_decoder_model:
                    max_encoder_seq_len = max(max_encoder_seq_len,
                                              inter_data.encoder_seq_len)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manages the cache by itself.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        batch_size = len(input_tokens)
        use_captured_graph = self._use_captured_graph(
            batch_size,
            max_decode_seq_len,
            max_encoder_seq_len=max_encoder_seq_len)

        # If cuda graph can be used, pad tensors accordingly.
        # See `capture_model` API for more details.
        # vLLM uses cuda graph only for decoding requests.
        # NOTE: -1 means "no cuda graph". It is truthy, so the
        # `if cuda_graph_pad_size:` branches below are still entered, but
        # `itertools.repeat(x, -1)` yields nothing, so no padding is added.
        cuda_graph_pad_size = -1
        if use_captured_graph:
            graph_batch_size = _get_graph_batch_size(batch_size)
            assert graph_batch_size >= batch_size
            cuda_graph_pad_size = graph_batch_size - batch_size
            batch_size = graph_batch_size

        # Tokens and positions.
        if cuda_graph_pad_size:
            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
        assert self.runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               self.runner.device,
                                               self.runner.pin_memory)
        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.long,
                                                      self.runner.device,
                                                      self.runner.pin_memory)
        # Sequence and query lengths.
        if cuda_graph_pad_size:
            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = set(r for data in self.inter_data_list
                                for r in data.lora_requests)
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                lora_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])

            lora_mapping = LoRAMapping(
                **dict(index_mapping=lora_index_mapping,
                       prompt_mapping=lora_prompt_mapping,
                       is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            if cuda_graph_pad_size:
                prompt_adapter_index_mapping.extend(
                    itertools.repeat(0, cuda_graph_pad_size))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_inputs_list = [
            data.multi_modal_inputs for data in self.inter_data_list
            if data.multi_modal_inputs is not None
        ]
        multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            input_positions=input_positions_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)
907
908


909
910
911
912
913
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
914
    _builder_cls: Type[ModelInputForGPUBuilder]
915
916
917
918
919
920

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        return_hidden_states: bool = False,
        observability_config: Optional[ObservabilityConfig] = None,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Store the configs and set up state that does not need the model.

        The model itself, the LoRA manager and the prompt adapter manager are
        not created here; they are set by ``load_model``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.prompt_adapter_config = prompt_adapter_config
        self.return_hidden_states = return_hidden_states
        self.observability_config = observability_config

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.max_batchsize_to_capture = _get_max_graph_batch_size(
            self.scheduler_config.max_num_seqs)

        # One dict of graph runners per pipeline-parallel stage.
        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
            parallel_config)

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables = np.zeros(
            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
            dtype=np.int32)
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        # No attention backend is selected when the model reports a falsy
        # number of attention heads.
        self.attn_backend = get_attn_backend(
            num_attn_heads,
            self.model_config.get_head_size(),
            self.model_config.get_num_kv_heads(self.parallel_config),
            self.model_config.get_sliding_window(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None
        # The attention state only gets a weak proxy to this runner.
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry
        self.multi_modal_input_mapper = mm_registry \
            .create_input_mapper(model_config)
        self.mm_registry.init_mm_limits_per_prompt(self.model_config)

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Used to cache python objects
        self.inter_data_cache: Dict[int, PyObjectCache] = {}
        self.sampling_metadata_cache: SamplingMetadataCache = \
            SamplingMetadataCache()

1011
    def load_model(self) -> None:
        """Load the model and wrap it with the configured adapters.

        After loading, the model is wrapped by the LoRA manager and/or the
        prompt adapter manager when the corresponding config is set. On ROCm
        with an fp8 KV cache, scaling factors are loaded from
        ``quantization_param_path`` if one is given. Finally the model may be
        compiled with ``torch.compile`` when the test env flag is set.

        Raises:
            RuntimeError: If FP8 KV cache scaling factors are provided but
                the model has no callable ``load_kv_cache_scales``.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:
            self.model = get_model(model_config=self.model_config,
                                   device_config=self.device_config,
                                   load_config=self.load_config,
                                   lora_config=self.lora_config,
                                   parallel_config=self.parallel_config,
                                   scheduler_config=self.scheduler_config,
                                   cache_config=self.cache_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

        if self.lora_config:
            assert supports_lora(self.model), "Model does not support LoRA"
            assert not supports_multimodal(
                self.model
            ), "To be tested: Multi-modal model with LoRA settings."

            self.lora_manager = LRUCacheWorkerLoRAManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens,
                self.vocab_size,
                self.lora_config,
                self.device,
                self.model.embedding_modules,
                self.model.embedding_padding_modules,
                max_position_embeddings=self.model.config.
                max_position_embeddings,
            )
            self.model = self.lora_manager.create_lora_manager(self.model)

        if self.prompt_adapter_config:
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                self.scheduler_config.max_num_seqs,
                self.scheduler_config.max_num_batched_tokens, self.device,
                self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))

        if self.kv_cache_dtype == "fp8" and is_hip():
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
            if self.model_config.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    warnings.warn(
                        "Loading kv cache scaling factor from JSON is "
                        "deprecated and will be removed. Please include "
                        "kv cache scaling factors in the model checkpoint.",
                        FutureWarning,
                        stacklevel=2)
                    self.model.load_kv_cache_scales(
                        self.model_config.quantization_param_path)
                    logger.info("Loaded KV cache scaling factors from %s",
                                self.model_config.quantization_param_path)
                else:
                    # Exceptions do not apply %-style arguments the way
                    # logging calls do, so the message is formatted here.
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        f"model {self.model.__class__} does not support "
                        "loading scaling factors.")
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
            from vllm.compilation.backends import vllm_backend
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend() or vllm_backend
            self.model = torch.compile(
                self.model,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)
1089

1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Save ``self.model`` via ``ShardedStateLoader.save_model``.

        Args:
            path: Forwarded positionally to ``save_model``.
            pattern: Forwarded as the ``pattern`` keyword.
            max_size: Forwarded as the ``max_size`` keyword.
        """
        from vllm.model_executor.model_loader.loader import (
            ShardedStateLoader as _ShardedStateLoader)

        save_options = dict(pattern=pattern, max_size=max_size)
        _ShardedStateLoader.save_model(self.model, path, **save_options)

1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Save ``self.model`` via ``TensorizerLoader.save_model``.

        Args:
            tensorizer_config: Forwarded as the ``tensorizer_config`` keyword.
        """
        from vllm.model_executor.model_loader.loader import (
            TensorizerLoader as _TensorizerLoader)

        model = self.model
        _TensorizerLoader.save_model(model,
                                     tensorizer_config=tensorizer_config)

1114
1115
    def get_max_block_per_batch(self) -> int:
        """Return ``max_seq_len_to_capture`` divided by ``block_size``,
        rounded up."""
        block_size = self.block_size
        # Ceiling division using integers only.
        return (self.max_seq_len_to_capture + block_size - 1) // block_size
1117

1118
    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForGPU:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        # The builder only receives a weak proxy to this runner.
        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        builder.reset_cached_inter_data()

        return builder.build()  # type: ignore
1144

1145
1146
1147
    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one forward pass on dummy prompts sized by scheduler limits.

        Dummy prompt sequence groups are built so that their total token count
        equals ``max_num_batched_tokens``, with dummy LoRA requests attached
        when LoRA is configured. The model is then executed once with
        ``None`` KV caches and CUDA is synchronized.
        """
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        # This represents the maximum number of different requests
        # that will have unique loras, and therefore the max amount of memory
        # consumption. Create dummy lora request copies from the lora request
        # passed in, which contains a lora from the lora warmup path.
        dummy_lora_requests: List[LoRARequest] = []
        dummy_lora_requests_per_seq: List[LoRARequest] = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
                    dummy_lora_request = LoRARequest(
                        lora_name=f"warmup_{lora_id}",
                        lora_int_id=lora_id,
                        lora_path="/not/a/real/path",
                    )
                    self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                     rank=LORA_WARMUP_RANK)
                    dummy_lora_requests.append(dummy_lora_request)
                # Assign the dummy requests to sequences round-robin.
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.

        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            seq_data, dummy_multi_modal_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=dummy_lora_requests_per_seq[group_id]
                if dummy_lora_requests_per_seq else None,
                multi_modal_data=dummy_multi_modal_data,
            )
            seqs.append(seq)

        # Run the model with the dummy inputs.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        kv_caches = [None] * num_layers
        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()

1238
    def remove_all_loras(self):
        """Remove every adapter held by the LoRA manager.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()
1242

1243
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Forward the requests and mapping to the LoRA manager's
        ``set_active_adapters``.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
1248
1249
1250
1251

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add the adapter via the LoRA manager and return its result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)
1253
1254
1255
1256

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the adapter with ``lora_id`` via the LoRA manager and
        return its result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)
1258
1259
1260
1261

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the adapter with ``lora_id`` via the LoRA manager and return
        its result.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)
1263
1264
1265
1266

    def list_loras(self) -> Set[int]:
        """Return the result of the LoRA manager's ``list_adapters``.

        Raises:
            RuntimeError: If no LoRA manager is set.
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()

    def remove_all_prompt_adapters(self):
        """Remove every adapter held by the prompt adapter manager.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if manager:
            manager.remove_all_adapters()
            return
        raise RuntimeError("PromptAdapter is not enabled.")

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        """Forward the requests and mapping to the prompt adapter manager's
        ``set_active_adapters``.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        manager.set_active_adapters(prompt_adapter_requests,
                                    prompt_adapter_mapping)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add the adapter via the prompt adapter manager and return its
        result.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        added = manager.add_adapter(prompt_adapter_request)
        return added

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove the adapter with ``prompt_adapter_id`` via the prompt
        adapter manager and return its result.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        removed = manager.remove_adapter(prompt_adapter_id)
        return removed

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin the adapter with ``prompt_adapter_id`` via the prompt adapter
        manager and return its result.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        pinned = manager.pin_adapter(prompt_adapter_id)
        return pinned

    def list_prompt_adapters(self) -> Set[int]:
        """Return the result of the prompt adapter manager's
        ``list_adapters``.

        Raises:
            RuntimeError: If no prompt adapter manager is set.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        adapter_ids = manager.list_adapters()
        return adapter_ids
1302

1303
1304
1305
1306
1307
1308
1309
1310
1311
    @property
    def model_is_mrope(self) -> bool:
        """Whether the HF config's ``rope_scaling`` has type "mrope".

        A missing ``rope_scaling`` attribute is treated as an empty mapping
        and an explicit ``None`` yields False.
        mrope requires keep "rope_deltas" between prompt and decoding phases.
        """
        hf_config = self.model_config.hf_config
        scaling_cfg = getattr(hf_config, "rope_scaling", {})
        return (scaling_cfg is not None
                and scaling_cfg.get("type", None) == "mrope")

1312
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One ``CUDAGraphRunner`` is captured for every (virtual engine,
        batch size) pair and stored in
        ``self.graph_runners[virtual_engine][batch_size]``.

        Args:
            kv_caches: KV caches indexed by virtual engine;
                ``kv_caches[virtual_engine]`` is handed to every capture
                performed for that engine.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing the model for CUDA graphs. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI.")
        logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                    "If you are running out of memory, consider decreasing "
                    "`gpu_memory_utilization` or enforcing eager mode. "
                    "You can also reduce the `max_num_seqs` as needed "
                    "to decrease memory usage.")
        start_time = time.perf_counter()

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = self.max_batchsize_to_capture
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        if self.model_is_mrope:
            # mrope models get a (3, max_batch_size) positions tensor.
            input_positions = torch.tile(input_positions, (3, 1))

        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
        if "previous_hidden_states" in inspect.signature(
                self.model.forward).parameters:
            previous_hidden_states = torch.empty(
                [max_batch_size,
                 self.model_config.get_hidden_size()],
                dtype=self.model_config.dtype,
                device=self.device)

        # Only ranks after the first pipeline stage receive intermediate
        # tensors as model input.
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        # Per-virtual-engine output buffer slots, initialised to None.
        # NOTE(review): nothing in this method assigns to these slots after a
        # capture, so every capture below receives None here — confirm that
        # this is intended.
        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
            None
        ] * self.parallel_config.pipeline_parallel_size

        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= max_batch_size
        ]

        with self.attn_state.graph_capture(
                max_batch_size), graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size,
                            is_encoder_decoder_model=self.model_config.
                            is_encoder_decoder_model))

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
                            **dict(index_mapping=[0] * batch_size,
                                   prompt_mapping=[0] * batch_size,
                                   is_prefill=False))
                        self.set_active_loras(set(), lora_mapping)

                    if self.prompt_adapter_config:
                        prompt_adapter_mapping = PromptAdapterMapping(
                            [-1] * batch_size,
                            [-1] * batch_size,
                        )
                        self.set_active_prompt_adapters(
                            set(), prompt_adapter_mapping)

                    graph_runner = CUDAGraphRunner(
                        self.model, self.attn_backend.get_name(),
                        self.attn_state.graph_clone(batch_size),
                        self.model_config.is_encoder_decoder_model)

                    engine_states = hidden_or_intermediate_states[
                        virtual_engine]
                    capture_inputs = {
                        "input_ids":
                        input_tokens[:batch_size],
                        "positions":
                        input_positions[..., :batch_size],
                        "hidden_or_intermediate_states":
                        (engine_states[:batch_size]
                         if engine_states is not None else None),
                        "intermediate_inputs":
                        (intermediate_inputs[:batch_size]
                         if intermediate_inputs is not None else None),
                        "kv_caches":
                        kv_caches[virtual_engine],
                        "attn_metadata":
                        attn_metadata,
                        "memory_pool":
                        self.graph_memory_pool,
                        "stream":
                        graph_capture_context.stream
                    }
                    if previous_hidden_states is not None:
                        capture_inputs["previous_hidden_states"] = (
                            previous_hidden_states[:batch_size])

                    if self.has_seqlen_agnostic:
                        # Only used by Mamba-based models CUDA graph atm (Jamba)
                        capture_inputs.update({
                            "seqlen_agnostic_capture_inputs":
                            self.model.get_seqlen_agnostic_capture_inputs(
                                batch_size)
                        })
                    if self.model_config.is_encoder_decoder_model:
                        # add the additional inputs to capture for
                        # encoder-decoder models.
                        self._update_inputs_to_capture_for_enc_dec_model(
                            capture_inputs)

                    graph_runner.capture(**capture_inputs)
                    # Hand the pool of the graph just captured to the next
                    # capture via `memory_pool`.
                    self.graph_memory_pool = graph_runner.graph.pool()
                    self.graph_runners[virtual_engine][batch_size] = (
                        graph_runner)

        end_time = time.perf_counter()
        elapsed_time = end_time - start_time
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs.", elapsed_time)
1455

1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
    def _update_inputs_to_capture_for_enc_dec_model(self,
                                                    capture_inputs: Dict[str,
                                                                         Any]):
        """
        Add the encoder-decoder specific entries to `capture_inputs`.

        The dictionary is modified in place: "encoder_input_ids" and
        "encoder_positions" are each set to a fresh empty long tensor moved
        to CUDA, so that they are part of what is captured for CUDA Graph
        replay.
        """
        # During the decode phase encoder_input_ids and encoder_positions are
        # unset. Do the same thing for graph capture.
        for name in ("encoder_input_ids", "encoder_positions"):
            capture_inputs[name] = torch.tensor([], dtype=torch.long).cuda()

1474
1475
1476
1477
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by ``model_config.get_vocab_size()``."""
        config = self.model_config
        return config.get_vocab_size()

1478

1479
1480
1481
1482
1483
1484
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcasted tensor dict, passing
        this runner's attention backend along."""
        model_input = \
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            )
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, model_input.seq_lens,
                model_input.query_lens, self.device, self.pin_memory,
                generators, self.sampling_metadata_cache)
        else:
            sampling_metadata = None
        # The first sequence group decides `is_prompt`; an empty list
        # yields None.
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"])
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """Run one forward pass and, on the last pipeline rank, sample.

        Returns:
            - the model's hidden/intermediate states when this rank is not
              the last pipeline-parallel rank;
            - an empty list on a last-rank worker that is not the driver;
            - otherwise a single-element list holding the sampler output.

        Raises:
            ValueError: if ``num_steps`` is greater than 1.
        """
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            # The captured runner is looked up by the (padded) batch size.
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}

        # Evaluated once; guards every use of the timing events below.
        collect_forward_time = (
            self.observability_config is not None
            and self.observability_config.collect_model_forward_time)
        if collect_forward_time:
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                         device=self.device),
            **seqlen_agnostic_kwargs)

        if collect_forward_time:
            model_forward_end.record()

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and collect_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                # Add the forward time carried in by the incoming
                # intermediate tensors, if any.
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if collect_forward_time and output is not None:
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
            orig_model_forward_time = 0.0
            if intermediate_tensors is not None:
                orig_model_forward_time = intermediate_tensors.tensors.get(
                    "model_forward_time", torch.tensor(0.0)).item()
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = (orig_model_forward_time +
                                         model_forward_time)

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta.use_cuda_graph:
                # Keep only the first len(indices) rows of the graph output.
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
1668
1669


1670
1671
class CUDAGraphRunner:
    """Captures one forward pass of a model into a CUDA graph and replays it.

    `capture` records the graph and keeps the tensors it was recorded with as
    input/output buffers; `forward` copies new inputs into those buffers,
    replays the graph and returns the output buffers.
    """

    def __init__(self, model: nn.Module, backend_name: str,
                 attn_state: AttentionState, is_encoder_decoder_model: bool):
        self.model = model
        self.backend_name = backend_name
        self.attn_state = attn_state

        # Both are filled in by `capture`.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None
        self._is_encoder_decoder_model = is_encoder_decoder_model

    @property
    def graph(self):
        """The captured CUDA graph; asserts that `capture` has run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
                                                      torch.Tensor]],
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Warm the model up, then record one forward pass as a CUDA graph.

        Args:
            input_ids, positions, kv_caches, attn_metadata: forwarded to the
                model and kept as replay input buffers.
            hidden_or_intermediate_states: optional existing output buffer;
                when given, the captured output is copied into it, otherwise
                the captured output itself becomes the buffer.
            intermediate_inputs: passed to the model as
                ``intermediate_tensors``; its tensors are added to the input
                buffers when present.
            memory_pool, stream: passed to ``torch.cuda.graph``.
            **kwargs: extra model inputs, forwarded to the model and stored
                in the input buffers.

        Returns:
            The output buffer that later replays return.
        """
        assert self._graph is None
        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.jit.script
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
        # Wait for the warm up operations to finish before proceeding with
        # Graph Capture.
        torch.cuda.synchronize()
        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                input_ids=input_ids,
                positions=positions,
                kv_caches=kv_caches,
                attn_metadata=attn_metadata,
                intermediate_tensors=intermediate_inputs,
                **kwargs,
            )
            if hidden_or_intermediate_states is not None:
                if get_pp_group().is_last_rank:
                    hidden_or_intermediate_states.copy_(
                        output_hidden_or_intermediate_states)
                else:
                    for key in hidden_or_intermediate_states.tensors:
                        hidden_or_intermediate_states[key].copy_(
                            output_hidden_or_intermediate_states[key])
            else:
                hidden_or_intermediate_states = (
                    output_hidden_or_intermediate_states)

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids":
            input_ids,
            "positions":
            positions,
            "kv_caches":
            kv_caches,
            **self.attn_state.get_graph_input_buffers(
                attn_metadata, self._is_encoder_decoder_model),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states
        return hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy the inputs into the captured buffers and replay the graph.

        Returns:
            ``output_buffers["hidden_states"]`` on the last pipeline rank,
            otherwise the whole ``output_buffers`` object stored by
            `capture`.
        """
        # KV caches are fixed tensors, so we don't need to copy them.
        del kv_caches

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        self.attn_state.prepare_graph_input_buffers(
            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                # The timing entries are skipped; every other key is copied
                # into its input buffer.
                if key != "model_execute_time" and key != "model_forward_time":
                    self.input_buffers[key].copy_(intermediate_tensors[key],
                                                  non_blocking=True)
        if self._is_encoder_decoder_model:
            self.input_buffers["encoder_input_ids"].copy_(
                kwargs['encoder_input_ids'], non_blocking=True)
            self.input_buffers["encoder_positions"].copy_(
                kwargs['encoder_positions'], non_blocking=True)

        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

1819

1820
def _get_graph_batch_size(batch_size: int) -> int:
1821
1822
1823
1824
1825
    """Returns the padded batch size given actual batch size.

    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
    """
1826
1827
1828
1829
1830
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    else:
1831
1832
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851


def _get_max_graph_batch_size(max_num_seqs: int) -> int:
    """
    Pick the largest batch size to capture for a given `max_num_seqs`.

    max_num_seqs: Maximum number of sequences in a batch.
    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.

    `max_num_seqs` is first padded with _get_graph_batch_size (which also
    takes care of the small edge cases 1, 2 and 4). A padded size that is
    one of _BATCH_SIZES_TO_CAPTURE is returned as is. Any other padded size
    must exceed the last entry of _BATCH_SIZES_TO_CAPTURE, and that last
    entry is returned instead.
    """
    padded = _get_graph_batch_size(max_num_seqs)
    if padded not in _BATCH_SIZES_TO_CAPTURE:
        largest_captured = _BATCH_SIZES_TO_CAPTURE[-1]
        assert padded > largest_captured
        return largest_captured
    return padded