model_runner.py 95.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import dataclasses
4
import gc
5
import inspect
6
import itertools
7
import time
8
import weakref
9
from contextlib import contextmanager
10
from dataclasses import dataclass
11
12
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set,
                    Tuple, Type, TypeVar, Union)
13

14
import numpy as np
15
import torch
16
import torch.distributed
17
import torch.nn as nn
18
from tqdm.auto import tqdm
19

20
import vllm.envs as envs
21
from vllm.attention import AttentionMetadata, get_attn_backend
22
from vllm.attention.backends.abstract import AttentionState
23
from vllm.attention.backends.utils import CommonAttentionState
24
from vllm.config import CompilationLevel, VllmConfig
25
from vllm.core.scheduler import SchedulerOutputs
26
27
from vllm.distributed import get_pp_group
from vllm.distributed.kv_transfer import get_kv_transfer_group
28
29
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
                                             graph_capture)
30
from vllm.forward_context import get_forward_context, set_forward_context
31
from vllm.inputs import INPUT_REGISTRY, InputRegistry
32
from vllm.logger import init_logger
33
34
35
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
36
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
37
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
38
39
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                get_sampler)
40
from vllm.model_executor.model_loader import get_model
41
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
42
from vllm.model_executor.models import supports_lora, supports_multimodal
43
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
44
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
45
                             MultiModalKwargs, MultiModalPlaceholderMap,
46
                             MultiModalRegistry)
47
48
49
50
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
    LRUCacheWorkerPromptAdapterManager)
51
from vllm.sampling_params import SamplingParams
52
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
53
54
55
56
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
                        async_tensor_h2d, flatten_2d_lists,
                        is_pin_memory_available, supports_dynamo,
                        weak_ref_tensor)
57
from vllm.worker.model_runner_base import (
58
59
    InputProcessingError, ModelRunnerBase, ModelRunnerInputBase,
    ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict,
60
61
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
62
    _init_sampling_metadata_from_tensor_dict)
63
64
65

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend
66
67
68

logger = init_logger(__name__)

# NOTE(review): the two warmup constants below are consumed outside this
# chunk; their descriptions are inferred from the names -- confirm at the
# call sites.
LORA_WARMUP_RANK = 8  # presumably the rank of dummy LoRAs used in warmup
_NUM_WARMUP_ITERS = 2  # presumably the number of warmup iterations

# Bound "self type": lets ModelInputForGPU.from_broadcasted_tensor_dict be
# typed as returning whichever subclass it is invoked on.
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")

# For now, bump up cache limits for recompilations during CUDA graph warmups.
torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128

79

80
@dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
    """Inputs of the base model forward pass on GPU.

    Only metadata needed by the forward pass itself lives here; model
    runners that run additional steps (e.g. sampling) subclass this
    dataclass to add the fields those steps need.
    """
    input_tokens: Optional[torch.Tensor] = None
    inputs_embeds: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    token_types: Optional[torch.Tensor] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    lora_mapping: Optional["LoRAMapping"] = None
    lora_requests: Optional[Set[LoRARequest]] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
    prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
    finished_requests_ids: Optional[List[str]] = None
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None
    scheduler_outputs: Optional[SchedulerOutputs] = None
    previous_hidden_states: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Collect the broadcastable fields into a flat dict.

        Only the fields named below are included, followed by whatever
        ``_add_attn_metadata_broadcastable_dict`` adds for ``attn_metadata``.
        Fields such as ``token_types``, ``seq_lens``, ``query_lens`` and
        ``async_callback`` are not part of the dict.
        """
        broadcast_fields = (
            "input_tokens",
            "inputs_embeds",
            "input_positions",
            "lora_requests",
            "lora_mapping",
            "multi_modal_kwargs",
            "prompt_adapter_mapping",
            "prompt_adapter_requests",
            "virtual_engine",
            "request_ids_to_seq_ids",
            "finished_requests_ids",
        )
        tensor_dict = {name: getattr(self, name) for name in broadcast_fields}
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForGPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForGPU:
        """Rebuild an instance from a dict made by
        ``as_broadcastable_tensor_dict``.

        When ``attn_backend`` is given, the attention-metadata entries are
        first converted back via ``_init_attn_metadata_from_tensor_dict``.
        """
        if attn_backend is None:
            return cls(**tensor_dict)
        return cls(**_init_attn_metadata_from_tensor_dict(
            attn_backend, tensor_dict))

    # Exclude `async_callback` to be able to pickle this object
    def __getstate__(self):
        state = dict(self.__dict__)
        del state["async_callback"]
        return state

    # TODO: What happens when we depickle this object?
    # How can we update this callback to properly pass it to the engine?
    def __setstate__(self, state):
        # Write through __dict__ because the dataclass is frozen.
        self.__dict__.update(state)
        self.__dict__["async_callback"] = None

147

148
@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
    """
    Used by the ModelRunner.

    Extends ``ModelInputForGPU`` with the sampling metadata needed after
    the forward pass.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None
    # Used for speculative decoding. We do not broadcast it because it is only
    # used by the driver worker.
    is_prompt: Optional[bool] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Serialize like the parent class, plus the sampling metadata."""
        # Reuse the parent's field list (which also adds the attention
        # metadata) instead of duplicating it, so a field added to the base
        # class is broadcast here as well.
        tensor_dict = super().as_broadcastable_tensor_dict()
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForGPUWithSamplingMetadata":
        """Inverse of ``as_broadcastable_tensor_dict``.

        Sampling metadata is always restored; attention metadata only when
        ``attn_backend`` is given.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


190
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
191
192
    """Build ModelInputForGPU from SequenceGroupMetadata."""

193
194
195
    # Note: ideally we would be using a dataclass(kw_only=True)
    # here, so that this can be subclassed easily,
    # but kw_only is not supported in python<3.10.
196
197
    class InterDataForSeqGroup:
        """Intermediate data for the current sequence group.

        Collects the CPU-side, per-sequence lists of one sequence group.
        Objects are pooled (see
        ``ModelInputForGPUBuilder.init_cached_inter_data``): calling
        ``__init__`` again with ``reinit=True`` resets an existing object in
        place, reusing its list allocations instead of creating new ones.
        """

        @staticmethod
        def _clear_rows(rows, num_rows: int):
            """Empty ``rows[i]`` for every ``i < num_rows``; return ``rows``."""
            for i in range(num_rows):
                rows[i].clear()
            return rows

        @staticmethod
        def _zero_prefix(values, num_values: int):
            """Set ``values[i] = 0`` for every ``i < num_values``; return it."""
            for i in range(num_values):
                values[i] = 0
            return values

        @staticmethod
        def _emptied(container):
            """Clear ``container`` in place and return it."""
            container.clear()
            return container

        def simple_reinit(self):
            """Reset a single-sequence group to its defaults, in place."""
            self.inputs_embeds = None  # type: ignore
            self.mrope_input_positions = None  # type: ignore
            for rows in (self.input_tokens,  # type: ignore
                         self.input_positions, self.token_types):
                rows[0].clear()
            for values in (self.seq_lens,  # type: ignore
                           self.orig_seq_lens, self.prompt_lens,
                           self.query_lens, self.context_lens,
                           self.curr_sliding_window_blocks):
                values[0] = 0
            for container in (self.lora_index_mapping,  # type: ignore
                              self.lora_prompt_mapping, self.lora_requests,
                              self.prompt_adapter_index_mapping,
                              self.prompt_adapter_prompt_mapping):
                container.clear()

        def __init__(
            self,
            *,
            # From sequence group metadata.
            request_id: str,
            seq_ids: List[int],
            is_prompt: bool,
            block_tables: Optional[Dict[int, List[int]]],
            computed_block_nums: List[int],
            n_seqs: int = 0,

            # Input tokens and positions.
            input_tokens: Optional[List[List[int]]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            input_positions: Optional[List[List[int]]] = None,
            token_types: Optional[List[List[int]]] = None,
            mrope_input_positions: Optional[List[List[List[int]]]] = None,

            # The sequence length (may be capped to the sliding window).
            seq_lens: Optional[List[int]] = None,
            # The original sequence length (before applying sliding window).
            # This is used to compute slot mapping.
            orig_seq_lens: Optional[List[int]] = None,
            # This is used in the dual-chunk flash attention backend.
            prompt_lens: Optional[List[int]] = None,
            # The query length.
            query_lens: Optional[List[int]] = None,
            # The number of tokens that are already computed.
            context_lens: Optional[List[int]] = None,
            # The current sliding window block.
            curr_sliding_window_blocks: Optional[List[int]] = None,

            # LoRA inputs.
            lora_index_mapping: Optional[List[List[int]]] = None,
            lora_prompt_mapping: Optional[List[List[int]]] = None,
            lora_requests: Optional[Set[LoRARequest]] = None,

            # Prompt adapter inputs.
            prompt_adapter_index_mapping: Optional[List[int]] = None,
            prompt_adapter_prompt_mapping: Optional[List[int]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None,

            # Multi-modal inputs.
            multi_modal_kwargs: Optional[MultiModalKwargs] = None,
            multi_modal_placeholder_maps: Optional[Dict[
                str, MultiModalPlaceholderMap]] = None,

            # Whether the prefix cache is hit (prefill only).
            prefix_cache_hit: bool = False,
            reinit: bool = False,
            reinit_use_defaults: bool = False,
            encoder_seq_len: int = 0,
        ):
            """Initialize, or (``reinit=True``) re-initialize in place.

            With ``reinit=False`` the token/position/length/LoRA-mapping
            lists are allocated fresh by ``__post_init__``, which overwrites
            the corresponding arguments.

            With ``reinit=True`` the object must already hold
            ``len(seq_ids)`` sequences. Non-empty list arguments replace the
            current values; empty/None ones reset the existing buffers in
            place. With ``reinit_use_defaults`` and a single sequence,
            ``simple_reinit`` is used instead and the token, length, LoRA
            and prompt-adapter mapping arguments are ignored.
            """
            if reinit:
                # Keep the pooled seq_ids list and overwrite it element-wise.
                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
                for pos, seq_id in enumerate(seq_ids):
                    self.seq_ids[pos] = seq_id  # type: ignore
            else:
                self.seq_ids = seq_ids

            self.request_id = request_id
            self.is_prompt = is_prompt
            self.block_tables = block_tables
            self.computed_block_nums = computed_block_nums
            self.n_seqs = n_seqs
            self.encoder_seq_len = encoder_seq_len

            if reinit and reinit_use_defaults and len(self.seq_ids) == 1:
                self.simple_reinit()
            elif reinit:
                num_seqs = len(self.seq_ids)
                # Each field: take the caller's non-empty value, otherwise
                # reset the existing buffer in place and keep it.
                self.input_tokens = input_tokens or self._clear_rows(
                    self.input_tokens, num_seqs)
                self.inputs_embeds = inputs_embeds
                self.input_positions = input_positions or self._clear_rows(
                    self.input_positions, num_seqs)
                self.token_types = token_types or self._clear_rows(
                    self.token_types, num_seqs)
                # The mrope_input_positions argument is not used on reinit.
                self.mrope_input_positions = None

                self.seq_lens = seq_lens or self._zero_prefix(
                    self.seq_lens, num_seqs)
                self.orig_seq_lens = orig_seq_lens or self._zero_prefix(
                    self.orig_seq_lens, num_seqs)
                self.prompt_lens = prompt_lens or self._zero_prefix(
                    self.prompt_lens, num_seqs)
                self.query_lens = query_lens or self._zero_prefix(
                    self.query_lens, num_seqs)
                self.context_lens = context_lens or self._zero_prefix(
                    self.context_lens, num_seqs)
                self.curr_sliding_window_blocks = (
                    curr_sliding_window_blocks or self._zero_prefix(
                        self.curr_sliding_window_blocks, num_seqs))

                self.lora_index_mapping = (
                    lora_index_mapping
                    or self._emptied(self.lora_index_mapping))
                self.lora_prompt_mapping = (
                    lora_prompt_mapping
                    or self._emptied(self.lora_prompt_mapping))
                self.lora_requests = (lora_requests
                                      or self._emptied(self.lora_requests))
                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping
                    or self._emptied(self.prompt_adapter_index_mapping))
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping
                    or self._emptied(self.prompt_adapter_prompt_mapping))
            else:
                self.input_tokens = input_tokens or []
                self.inputs_embeds = inputs_embeds
                self.input_positions = input_positions or []
                self.token_types = token_types or []
                self.mrope_input_positions = mrope_input_positions or None
                self.seq_lens = seq_lens or []
                self.orig_seq_lens = orig_seq_lens or []
                self.prompt_lens = prompt_lens or []
                self.query_lens = query_lens or []
                self.context_lens = context_lens or []
                self.curr_sliding_window_blocks = (curr_sliding_window_blocks
                                                   or [])
                self.lora_index_mapping = lora_index_mapping or []
                self.lora_prompt_mapping = lora_prompt_mapping or []
                self.lora_requests = lora_requests or set()
                self.prompt_adapter_index_mapping = (
                    prompt_adapter_index_mapping or [])
                self.prompt_adapter_prompt_mapping = (
                    prompt_adapter_prompt_mapping or [])

            self.prompt_adapter_request = prompt_adapter_request
            self.multi_modal_kwargs = multi_modal_kwargs
            self.multi_modal_placeholder_maps = multi_modal_placeholder_maps
            self.prefix_cache_hit = prefix_cache_hit

            self.n_seqs = len(self.seq_ids)

            if not reinit:
                self.__post_init__()

        def __post_init__(self):
            """Allocate fresh, empty per-sequence buffers for ``seq_ids``."""
            self.n_seqs = len(self.seq_ids)
            n = self.n_seqs

            # One independent list per sequence.
            self.input_tokens = [[] for _ in range(n)]
            self.input_positions = [[] for _ in range(n)]
            self.token_types = [[] for _ in range(n)]
            self.mrope_input_positions = None

            self.seq_lens = [0] * n
            self.orig_seq_lens = [0] * n
            self.prompt_lens = [0] * n
            self.query_lens = [0] * n
            self.context_lens = [0] * n
            self.curr_sliding_window_blocks = [0] * n

            self.lora_index_mapping = []
            self.lora_prompt_mapping = []

        def __repr__(self) -> str:
            # Closes the parenthesis opened after the class name and includes
            # prompt_lens alongside the other per-sequence length lists.
            return (f"InterDataForSeqGroup("
                    f"request_id={self.request_id}, "
                    f"seq_ids={self.seq_ids}, "
                    f"is_prompt={self.is_prompt}, "
                    f"block_tables={self.block_tables}, "
                    f"computed_block_nums={self.computed_block_nums}, "
                    f"n_seqs={self.n_seqs}, "
                    f"input_tokens={self.input_tokens}, "
                    f"inputs_embeds.shape="
                    f"{getattr(self.inputs_embeds, 'shape', None)}, "
                    f"input_positions={self.input_positions}, "
                    f"token_types={self.token_types}, "
                    f"mrope_input_positions={self.mrope_input_positions}, "
                    f"seq_lens={self.seq_lens}, "
                    f"orig_seq_lens={self.orig_seq_lens}, "
                    f"prompt_lens={self.prompt_lens}, "
                    f"query_lens={self.query_lens}, "
                    f"context_lens={self.context_lens}, "
                    f"multi_modal_kwargs={self.multi_modal_kwargs})")

444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
    def gen_inter_data_builder(self, num_seqs: int):
        """Return a zero-argument factory of placeholder inter-data objects.

        Each call creates an ``InterDataForSeqGroup`` holding ``num_seqs``
        dummy sequence ids (all 0). ``init_cached_inter_data`` re-runs
        ``__init__`` on these objects before handing them out.
        """

        def build_placeholder():
            return ModelInputForGPUBuilder.InterDataForSeqGroup(
                request_id="",
                seq_ids=[0] * num_seqs,
                is_prompt=True,
                block_tables=None,
                computed_block_nums=[],
            )

        return build_placeholder

    def init_cached_inter_data(self, *args, **kwargs):
        """Take a pooled ``InterDataForSeqGroup`` and (re)initialize it.

        Pools are keyed by the number of sequences, so ``seq_ids`` must be
        passed as a keyword argument. Positional arguments are not accepted.
        All keyword arguments are forwarded to ``__init__``.
        """
        assert not args
        assert "seq_ids" in kwargs
        num_seqs = len(kwargs["seq_ids"])

        # The inter-data cache is per model_runner; pools are created lazily
        # the first time a given group size is seen.
        pools = self.runner.inter_data_cache
        if num_seqs not in pools:
            pools[num_seqs] = PyObjectCache(
                self.gen_inter_data_builder(num_seqs))

        inter_data = pools[num_seqs].get_object()
        inter_data.__init__(*args, **kwargs)
        return inter_data

    def reset_cached_inter_data(self):
        """Call ``reset()`` on every pool in the runner's inter-data cache."""
        for pool in self.runner.inter_data_cache.values():
            pool.reset()
471
472
473
474
475

    def __init__(self,
                 runner: "GPUModelRunnerBase",
                 finished_requests_ids: Optional[List[str]] = None):
        """Create an input builder bound to ``runner``.

        Args:
            runner: the model runner whose configs and attention backend
                this builder reads.
            finished_requests_ids: accepted here but not used; ``prepare()``
                is what stores it.
        """
        super().__init__()

        # Compute functions for each sequence in a sequence group.
        # WARNING: The order of the functions matters! Later steps read
        # values (context/seq/query lengths) written by earlier ones.
        self.per_seq_compute_fns = [
            self._compute_lens,
            self._compute_for_prefix_cache_hit,
            self._compute_for_sliding_window,
            self._compute_lora_input,
        ]
        # Compute functions for each sequence group.
        # WARNING: The order of the functions matters!
        self.per_seq_group_compute_fns = [
            self._compute_prompt_adapter_input,
            self._compute_multi_modal_input,
        ]

        # Mirror the runner's configuration for quick access.
        self.runner = runner
        self.model_input_cls = runner._model_input_cls
        self.attn_backend = runner.attn_backend
        self.scheduler_config = runner.scheduler_config
        self.sliding_window = runner.sliding_window
        self.block_size = runner.block_size
        self.enable_lora = runner.lora_config is not None
        self.enable_prompt_adapter = runner.prompt_adapter_config is not None

        # Attention metadata inputs. Spec decode (e.g. Medusa) does not have
        # an attention backend, in which case no builder is created. The
        # builder gets a weak proxy, presumably to avoid a reference cycle.
        if self.attn_backend is not None:
            self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
                weakref.proxy(self))

        # Engine/Model configurations.
        self.chunked_prefill_enabled = (
            self.scheduler_config is not None
            and self.scheduler_config.chunked_prefill_enabled)
        if self.sliding_window is not None:
            # Window rounded up to whole blocks: in blocks, then in tokens.
            self.sliding_window_blocks = (
                self.sliding_window + self.block_size - 1) // self.block_size
            self.block_aligned_sliding_window = (self.sliding_window_blocks *
                                                 self.block_size)

517
518
519
520
    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        """Reset per-batch state before sequence groups are added.

        Args:
            finished_requests_ids: stored as-is on the builder.
        """
        self.finished_requests_ids = finished_requests_ids

        # Whether the current batch is decode-only; it starts True and is
        # set to False if there is any non-decode request.
        self.decode_only = True

        # Intermediate data (data in CPU before going to GPU), one entry per
        # sequence group.
        self.inter_data_list: List[
            ModelInputForGPUBuilder.InterDataForSeqGroup] = []

        # NOTE(review): ``attn_metadata_builder`` is only created in __init__
        # when an attention backend exists, so this raises AttributeError
        # for backend-less builders -- confirm those never call prepare().
        self.attn_metadata_builder.prepare()

532
533
534
535
536
537
538
    def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int,
                      seq_group_metadata: SequenceGroupMetadata):
        """Compute context length, sequence length and tokens
        for the given sequence data.

        Writes ``seq_lens``, ``orig_seq_lens``, ``prompt_lens``,
        ``context_lens`` and ``query_lens`` at ``seq_idx``. Appends the
        tokens, positions and token types of the window
        ``[context_len, seq_len)`` to the per-sequence lists. Sets
        ``inputs_embeds`` when the sequence carries prompt embeddings.
        Fills ``mrope_input_positions`` when the sequence has an M-RoPE
        position delta.
        """
        seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]]
        token_chunk_size = seq_group_metadata.token_chunk_size

        # Compute context length (the number of tokens that are
        # already computed) and sequence length (total number of tokens).
        seq_len = seq_data.get_len()
        if inter_data.is_prompt:
            context_len = seq_data.get_num_computed_tokens()
            # At most ``token_chunk_size`` new tokens are processed this step.
            seq_len = min(seq_len, context_len + token_chunk_size)
        elif (self.runner.scheduler_config.is_multi_step
              or self.runner.model_config.is_encoder_decoder):
            context_len = seq_len - 1
        else:
            context_len = seq_data.get_num_computed_tokens()

        # Compute tokens.
        if seq_data.prompt_embeds is None:
            tokens = seq_data.get_token_ids()[context_len:seq_len]
            prompt_embeds = None
        else:
            # Placeholder ids keep input_tokens as long as the embeddings.
            tokens = [0] * (seq_len - context_len)
            prompt_embeds = seq_data.get_token_embeddings(
            )[context_len:seq_len]

        # Take the same [context_len, seq_len) window of the token type ids
        # as of the tokens, so token_types stays aligned with input_tokens
        # when only part of the sequence is computed in this step (e.g. a
        # later prefill chunk). For a full first prefill this is the whole
        # list.
        token_types = seq_group_metadata.token_type_ids
        token_types = token_types[context_len:seq_len] if token_types else []

        inter_data.seq_lens[seq_idx] = seq_len
        inter_data.orig_seq_lens[seq_idx] = seq_len
        inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
        inter_data.context_lens[seq_idx] = context_len
        inter_data.input_tokens[seq_idx].extend(tokens)
        # NOTE(review): one tensor per group, overwritten for every sequence
        # -- confirm groups with prompt embeddings hold a single sequence.
        inter_data.inputs_embeds = prompt_embeds
        inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
        inter_data.token_types[seq_idx].extend(token_types)
        inter_data.query_lens[seq_idx] = seq_len - context_len

        if seq_data.mrope_position_delta is not None:
            if inter_data.mrope_input_positions is None:
                inter_data.mrope_input_positions = [None] * inter_data.n_seqs

            inter_data.mrope_input_positions[
                seq_idx] = MRotaryEmbedding.get_next_input_positions(
                    seq_data.mrope_position_delta,
                    context_len,
                    seq_len,
                )

586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
    def _compute_for_prefix_cache_hit(
            self, inter_data: InterDataForSeqGroup, seq_idx: int,
            seq_group_metadata: SequenceGroupMetadata):
        """Trim this step's inputs when part of the prompt is prefix-cached.

        A hit means some prompt blocks are already computed
        (``computed_block_nums`` is non-empty). On a hit, the input tokens,
        positions and token types of ``seq_idx`` are cut so that only the
        uncached remainder is computed, and ``context_lens`` /
        ``query_lens`` are updated to match.
        """
        computed_block_nums = inter_data.computed_block_nums

        # Note that prefix caching does not support sliding window, and it is
        # only applied to prompts.
        hit = (computed_block_nums is not None
               and len(computed_block_nums) > 0
               and self.sliding_window is None and inter_data.is_prompt)
        inter_data.prefix_cache_hit = hit
        if not hit:
            return

        assert computed_block_nums is not None
        # The cache hit prompt tokens in this sequence. Note that this may be
        # larger than the sequence length if chunked prefill is enabled.
        prefix_cache_len = len(computed_block_nums) * self.block_size
        seq_id = inter_data.seq_ids[seq_idx]
        seq_group_metadata.seq_data[seq_id].update_num_cached_tokens(
            prefix_cache_len)

        # Prompt tokens computed so far, and prompt tokens up to the end of
        # this step (computed chunks + current chunk with chunked prefill).
        context_len = inter_data.context_lens[seq_idx]
        seq_len = inter_data.seq_lens[seq_idx]
        if prefix_cache_len <= context_len:
            # Already past the cache hit region: compute normally.
            return

        per_seq_inputs = (inter_data.input_tokens, inter_data.input_positions,
                          inter_data.token_types)
        if prefix_cache_len < seq_len:
            # Partial hit: drop the cached tokens and compute the rest.
            skip = prefix_cache_len - context_len
            for values in per_seq_inputs:
                values[seq_idx] = values[seq_idx][skip:]
            inter_data.context_lens[seq_idx] = prefix_cache_len
            inter_data.query_lens[
                seq_idx] = inter_data.seq_lens[seq_idx] - prefix_cache_len
        else:
            # Full hit. Only compute the last token to avoid erroneous
            # behavior. FIXME: Ideally we should directly mark all tokens as
            # computed in the scheduler and do not schedule this sequence,
            # so this case should not happen.
            for values in per_seq_inputs:
                values[seq_idx] = values[seq_idx][-1:]
            inter_data.query_lens[seq_idx] = 1
            inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1
        # NOTE(review): inter_data.inputs_embeds is not trimmed here; confirm
        # prompt embeddings cannot be combined with prefix caching.
650
651
652
653
654
655
656
657
658
659
660
661
662
663

    def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
                                    seq_idx: int,
                                    seq_group_metadata: SequenceGroupMetadata):
        """Update seq_len and curr_sliding_window_block for the given
        sequence data (only required by decoding) if sliding window is enabled.

        For prompts, or when no sliding window is configured, the block
        count is set to 0 and ``seq_lens`` keeps its value.
        """
        seq_len = inter_data.seq_lens[seq_idx]
        if inter_data.is_prompt or self.sliding_window is None:
            window_blocks = 0
            capped_seq_len = seq_len
        else:
            # TODO(sang): This is a hack to make sliding window work with
            # paged attn. We can remove it if we make paged attn kernel
            # to properly handle sliding window attn.
            # Number of tokens in the last, partially filled block.
            tail_len = seq_len % self.block_size
            capped_seq_len = min(seq_len,
                                 self.block_aligned_sliding_window + tail_len)
            window_blocks = self.sliding_window_blocks
            if tail_len > 0:
                window_blocks += 1

        inter_data.curr_sliding_window_blocks[seq_idx] = window_blocks
        inter_data.seq_lens[seq_idx] = capped_seq_len

    def _compute_lora_input(self, inter_data: InterDataForSeqGroup,
                            seq_idx: int,
                            seq_group_metadata: SequenceGroupMetadata):
        """Record LoRA index and prompt mappings for one sequence.

        This is a no-op when LoRA is disabled. A positive ``lora_int_id`` also
        registers the group's LoRA request in ``inter_data.lora_requests``.

        Mappings appended:
            * Index mapping: one LoRA id per query token.
            * Prompt mapping, one of:
                - one id per query token when prompt logprobs are requested;
                - nothing when chunked prefill is enabled and this step does
                  not sample;
                - a single id otherwise.
        """
        if not self.enable_lora:
            return

        lora_id = seq_group_metadata.lora_int_id
        if lora_id > 0:
            inter_data.lora_requests.add(seq_group_metadata.lora_request)

        n_query_tokens = inter_data.query_lens[seq_idx]
        inter_data.lora_index_mapping.append([lora_id] * n_query_tokens)

        params = seq_group_metadata.sampling_params
        if params and params.prompt_logprobs is not None:
            prompt_ids = [lora_id] * n_query_tokens
        elif (self.chunked_prefill_enabled
              and not seq_group_metadata.do_sample):
            # Nothing is sampled in this chunked-prefill step.
            prompt_ids = []
        else:
            prompt_ids = [lora_id]
        inter_data.lora_prompt_mapping.append(prompt_ids)

    def _compute_prompt_adapter_input(
            self, inter_data: InterDataForSeqGroup,
            seq_group_metadata: SequenceGroupMetadata):
        """Fill the prompt-adapter request and mappings for a prompt group.

        This is a no-op unless prompt adapters are enabled, the group's
        ``prompt_adapter_id`` is positive, and the group is a prompt.

        Mappings written:
            * Index mapping: the first ``prompt_adapter_num_virtual_tokens``
              entries hold the adapter id and the remaining query tokens
              hold 0.
            * Prompt mapping: covers every query token when prompt logprobs
              are requested, otherwise a single entry.
        """
        if not self.enable_prompt_adapter:
            return

        adapter_id = seq_group_metadata.prompt_adapter_id
        if adapter_id <= 0 or not inter_data.is_prompt:
            return

        # Prompt groups are expected to carry exactly one sequence.
        assert inter_data.n_seqs == 1
        n_query_tokens = inter_data.query_lens[0]
        inter_data.prompt_adapter_request = (
            seq_group_metadata.prompt_adapter_request)

        # NOTE(review): if n_virtual exceeds n_query_tokens (e.g. a short
        # prefill chunk), `[0] * n_plain` is empty and the index mapping ends
        # up longer than the query; confirm whether that case can occur.
        n_virtual = seq_group_metadata.prompt_adapter_num_virtual_tokens
        n_plain = n_query_tokens - n_virtual
        inter_data.prompt_adapter_index_mapping = (
            [adapter_id] * n_virtual + [0] * n_plain)

        params = seq_group_metadata.sampling_params
        covers_all_tokens = params and params.prompt_logprobs
        inter_data.prompt_adapter_prompt_mapping = [adapter_id] * (
            n_query_tokens if covers_all_tokens else 1)

    def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
                                   seq_group_metadata: SequenceGroupMetadata):
        """Attach multi-modal kwargs (and M-RoPE positions) to ``inter_data``.

        Multi-modal items are gathered for the position range covered by the
        first sequence's input positions. When the model uses M-RoPE, the
        M-RoPE input positions are also computed for every sequence, and the
        sequence's ``mrope_position_delta`` is stored on its sequence data.
        """
        seq_positions = inter_data.input_positions[0]
        first_pos = seq_positions[0]
        # NOTE: mm_kwargs only includes the subset of multi-modal items that
        # intersect with the current prefill positions.
        mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
            seq_group_metadata,
            range(first_pos, first_pos + len(seq_positions)))
        if not mm_kwargs:
            return

        inter_data.multi_modal_kwargs = mm_kwargs
        inter_data.multi_modal_placeholder_maps = placeholder_maps

        if not self.runner.model_config.uses_mrope:
            return

        # Special processing for M-RoPE position deltas.
        image_grid_thw = mm_kwargs.get("image_grid_thw", None)
        video_grid_thw = mm_kwargs.get("video_grid_thw", None)
        audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", None)
        assert any(
            feature is not None
            for feature in (image_grid_thw, video_grid_thw,
                            audio_feature_lengths)), (
                "mrope embedding type requires multi-modal input mapper "
                "returns 'image_grid_thw' or 'video_grid_thw' or "
                "'audio_feature_lengths'.")

        second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
        use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
        hf_config = self.runner.model_config.hf_config

        inter_data.mrope_input_positions = [None] * inter_data.n_seqs
        for seq_idx in range(inter_data.n_seqs):
            seq_id = inter_data.seq_ids[seq_idx]
            seq_data = seq_group_metadata.seq_data[seq_id]
            positions, position_delta = MRotaryEmbedding.get_input_positions(
                seq_data.get_token_ids(),
                hf_config=hf_config,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
                context_len=inter_data.context_lens[seq_idx],
                seq_len=inter_data.seq_lens[seq_idx],
                audio_feature_lengths=audio_feature_lengths,
                use_audio_in_video=use_audio_in_video,
            )
            seq_data.mrope_position_delta = position_delta
            inter_data.mrope_input_positions[seq_idx] = positions

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        """Add a sequence group to the builder.

        Steps:
            1. Obtain an intermediate-data record for the group from
               ``init_cached_inter_data`` and append it to
               ``self.inter_data_list``.
            2. Run every per-sequence compute function for each sequence.
            3. Run every per-sequence-group compute function once.
        """
        seq_ids = seq_group_metadata.seq_data.keys()
        n_seqs = len(seq_ids)
        is_prompt = seq_group_metadata.is_prompt

        if is_prompt:
            # A prompt group holds exactly one sequence, and any prompt in
            # the batch marks the batch as not decode-only.
            assert n_seqs == 1
            self.decode_only = False

        if self.runner.model_config.is_encoder_decoder:
            encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
        else:
            encoder_seq_len = 0

        inter_data = self.init_cached_inter_data(
            request_id=seq_group_metadata.request_id,
            seq_ids=seq_ids,
            is_prompt=is_prompt,
            block_tables=seq_group_metadata.block_tables,
            computed_block_nums=seq_group_metadata.computed_block_nums,
            reinit=True,
            reinit_use_defaults=True,
            encoder_seq_len=encoder_seq_len)
        self.inter_data_list.append(inter_data)

        for seq_idx in range(n_seqs):
            for compute_fn in self.per_seq_compute_fns:
                compute_fn(inter_data, seq_idx, seq_group_metadata)
        for group_fn in self.per_seq_group_compute_fns:
            group_fn(inter_data, seq_group_metadata)

    def _use_captured_graph(self,
                            batch_size: int,
813
                            decode_only: bool,
814
815
                            max_decode_seq_len: int,
                            max_encoder_seq_len: int = 0) -> bool:
816
        return (decode_only and not self.runner.model_config.enforce_eager
817
818
819
                and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                and batch_size <= self.runner.max_batchsize_to_capture)
820

821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
    def _get_cuda_graph_pad_size(self,
                                 num_seqs: int,
                                 max_decode_seq_len: int,
                                 max_encoder_seq_len: int = 0) -> int:
        """
        Determine the number of padding sequences required for running in
        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.

        In the multi-step + chunked-prefill case, only the first step
        has Prefills (if any). The rest of the steps are guaranteed to be all
        decodes. In this case, we set up the padding as if all the sequences
        are decodes so we may run all steps except the first step in CUDA graph
        mode. The padding is accounted for in the multi-step `advance_step`
        family of functions.

        Args:
837
            num_seqs (int): Number of sequences scheduled to run.
838
839
840
841
842
            max_decode_seq_len (int): Greatest of all the decode sequence
                lengths. Used only in checking the viablility of using
                CUDA graphs.
            max_encoder_seq_len (int, optional): Greatest of all the encode
                sequence lengths. Defaults to 0. Used only in checking the
843
                viability of using CUDA graphs.
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
        Returns:
            int: Returns the determined number of padding sequences. If
                CUDA graphs is not viable, returns -1.
        """
        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
                    self.runner.scheduler_config.chunked_prefill_enabled
        decode_only = self.decode_only or is_mscp
        if not decode_only:
            # Early exit so we can treat num_seqs as the batch_size below.
            return -1

        # batch_size out of this function refers to the number of input
        # tokens being scheduled. This conflation of num_seqs as batch_size
        # is valid as this is a decode-only case.
        batch_size = num_seqs
        if not self._use_captured_graph(batch_size, decode_only,
                                        max_decode_seq_len,
                                        max_encoder_seq_len):
            return -1

864
865
        graph_batch_size = self.runner.vllm_config.pad_for_cudagraph(
            batch_size)
866
867
868
        assert graph_batch_size >= batch_size
        return graph_batch_size - batch_size

869
    def build(self) -> ModelInputForGPU:
        """Finalize the builder intermediate data and create on-device tensors.

        Flattens the per-group intermediate data accumulated by
        ``add_seq_group`` into batch-level lists. When
        ``_get_cuda_graph_pad_size`` allows CUDA graph execution, tokens,
        positions, sequence lengths and LoRA / prompt-adapter index mappings
        are padded. The results are then copied to ``self.runner.device``.

        Returns:
            An instance of ``self.model_input_cls``. When there are neither
            input tokens nor input embeddings (e.g. every prefill request was
            a full prefix-cache hit and there is no decode), an empty
            instance is returned.
        """
        runner = self.runner

        # Combine and flatten intermediate data.
        input_tokens = list[int]()
        inputs_embeds_lst = list[torch.Tensor]()
        token_types = list[int]()
        for inter_data in self.inter_data_list:
            for cur_input_tokens in inter_data.input_tokens:
                input_tokens.extend(cur_input_tokens)
            for cur_token_types in inter_data.token_types:
                token_types.extend(cur_token_types)
            if inter_data.inputs_embeds is not None:
                inputs_embeds_lst.append(
                    inter_data.inputs_embeds.to(
                        dtype=runner.model_config.dtype,
                        device=runner.device))
        inputs_embeds: Optional[torch.Tensor]
        if len(inputs_embeds_lst) == 0:
            inputs_embeds = None
        else:
            inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
                dtype=runner.model_config.dtype, device=runner.device)
            assert len(inputs_embeds) == len(input_tokens)

        if not input_tokens and inputs_embeds is None:
            # This may happen when all prefill requests hit
            # prefix caching and there is no decode request.
            return self.model_input_cls()

        mrope_input_positions: Optional[List[List[int]]] = None
        if any(inter_data.mrope_input_positions is not None
               for inter_data in self.inter_data_list):
            # M-RoPE positions are built as three sections; groups without
            # M-RoPE data contribute their plain positions to each section.
            mrope_input_positions = [[] for _ in range(3)]
            for idx in range(3):
                for inter_data in self.inter_data_list:
                    msections = inter_data.mrope_input_positions
                    if msections is None:
                        for _seq_input_positions in inter_data.input_positions:
                            mrope_input_positions[idx].extend(
                                _seq_input_positions)
                    else:
                        for _seq_mrope_input_positions in msections:
                            mrope_input_positions[idx].extend(
                                _seq_mrope_input_positions[idx])
            input_positions = None
        else:
            input_positions = []
            for inter_data in self.inter_data_list:
                for cur_input_positions in inter_data.input_positions:
                    input_positions.extend(cur_input_positions)

        seq_lens = []
        query_lens = []
        max_decode_seq_len = 0
        max_encoder_seq_len = 0
        for inter_data in self.inter_data_list:
            seq_lens.extend(inter_data.seq_lens)
            query_lens.extend(inter_data.query_lens)
            if not inter_data.is_prompt:
                max_decode_seq_len = max(max_decode_seq_len,
                                         max(inter_data.seq_lens))
                if runner.model_config.is_encoder_decoder:
                    max_encoder_seq_len = max(max_encoder_seq_len,
                                              inter_data.encoder_seq_len)

        # Mapping from request IDs to sequence IDs. Used for Jamba models
        # that manages the cache by itself.
        request_ids_to_seq_ids = {
            data.request_id: data.seq_ids
            for data in self.inter_data_list
        }

        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
            num_seqs=len(seq_lens),
            max_decode_seq_len=max_decode_seq_len,
            max_encoder_seq_len=max_encoder_seq_len)
        # `cuda_graph_pad_size` is -1 when CUDA graphs cannot be used. The
        # raw value is forwarded to the attention metadata builder, while all
        # list padding below uses this clamped count. Clamping means the
        # padding no longer relies on itertools.repeat(x, -1) yielding
        # nothing (a bare `if cuda_graph_pad_size:` guard is truthy for -1).
        num_pad = max(cuda_graph_pad_size, 0)

        # If cuda graph can be used, the batch is padded accordingly (see the
        # `capture_model` API). vLLM uses cuda graph only for decoding
        # requests.
        batch_size = len(input_tokens) + num_pad

        # Tokens and positions.
        input_tokens.extend(itertools.repeat(0, num_pad))
        assert runner.device is not None
        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                               runner.device,
                                               runner.pin_memory)

        # NOTE(review): token types are not padded; this appears to rely on
        # token types only being present when CUDA graphs are not used.
        token_types_tensor = async_tensor_h2d(
            token_types, torch.long, runner.device,
            runner.pin_memory) if token_types else None

        if mrope_input_positions is not None:
            for idx in range(3):
                mrope_input_positions[idx].extend(
                    itertools.repeat(0, num_pad))
            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
                                                      torch.long,
                                                      runner.device,
                                                      runner.pin_memory)
        else:
            input_positions.extend(itertools.repeat(0, num_pad))
            input_positions_tensor = async_tensor_h2d(input_positions,
                                                      torch.long,
                                                      runner.device,
                                                      runner.pin_memory)

        # Sequence and query lengths. Padding entries get length 1.
        seq_lens.extend(itertools.repeat(1, num_pad))

        # Attention metadata.
        attn_metadata = self.attn_metadata_builder.build(
            seq_lens, query_lens, cuda_graph_pad_size, batch_size)

        # LoRA data.
        lora_requests: Set[LoRARequest] = set()
        lora_mapping = None
        if self.enable_lora:
            lora_requests = {
                r
                for data in self.inter_data_list for r in data.lora_requests
            }
            lora_index_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_index_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_index_mapping.extend(itertools.repeat(0, num_pad))
            lora_prompt_mapping = flatten_2d_lists([
                flatten_2d_lists(inter_data.lora_prompt_mapping)
                for inter_data in self.inter_data_list
            ])
            lora_mapping = LoRAMapping(index_mapping=lora_index_mapping,
                                       prompt_mapping=lora_prompt_mapping,
                                       is_prefill=not self.decode_only)

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()
        prompt_adapter_mapping = None
        if self.enable_prompt_adapter:
            prompt_adapter_requests = set(
                data.prompt_adapter_request for data in self.inter_data_list
                if data.prompt_adapter_request is not None)
            prompt_adapter_index_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_index_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_index_mapping.extend(itertools.repeat(0, num_pad))
            prompt_adapter_prompt_mapping = flatten_2d_lists([
                inter_data.prompt_adapter_prompt_mapping
                for inter_data in self.inter_data_list
            ])
            prompt_adapter_mapping = PromptAdapterMapping(
                prompt_adapter_index_mapping,
                prompt_adapter_prompt_mapping,
            )

        # Multi-modal data.
        multi_modal_kwargs_list = [
            data.multi_modal_kwargs for data in self.inter_data_list
            if data.multi_modal_kwargs is not None
        ]
        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return self.model_input_cls(
            input_tokens=input_tokens_tensor,
            inputs_embeds=inputs_embeds,
            input_positions=input_positions_tensor,
            token_types=token_types_tensor,
            attn_metadata=attn_metadata,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=lora_mapping,
            lora_requests=lora_requests,
            multi_modal_kwargs=multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids,
            finished_requests_ids=self.finished_requests_ids,
            prompt_adapter_mapping=prompt_adapter_mapping,
            prompt_adapter_requests=prompt_adapter_requests)


class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
    """
    Helper class for shared methods between GPU model runners.
    """
    _model_input_cls: Type[TModelInputForGPU]
1066
    _builder_cls: Type[ModelInputForGPUBuilder]
1067
    builder: ModelInputForGPUBuilder
1068
1069
1070

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):
        """Set up state shared by the GPU model runners.

        Args:
            vllm_config: Engine configuration, passed to
                ``ModelRunnerBase.__init__``.
            kv_cache_dtype: KV cache dtype, passed to ``get_attn_backend``.
            is_driver_worker: Stored as ``self.is_driver_worker``.
            return_hidden_states: Stored as ``self.return_hidden_states``.
            input_registry: Stored as ``self.input_registry``.
            mm_registry: Stored as ``self.mm_registry``.

        The model itself is not loaded here; ``load_model`` sets
        ``self.model`` and the optional LoRA / prompt-adapter managers.
        """
        ModelRunnerBase.__init__(self, vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config

        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.max_batchsize_to_capture = \
            self.vllm_config.compilation_config.max_capture_size

        # One dict of CUDA graph runners per pipeline-parallel stage, keyed by
        # an (int, bool) tuple -- presumably (batch size, flag); confirm
        # against the graph capture code.
        self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
            {} for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.has_inner_state = model_config.has_inner_state

        self.in_profile_run = False

        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max seq len to capture / block size).
        self.graph_block_tables = np.zeros(
            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
            dtype=np.int32)

        # Attention-free but stateful models like Mamba need a placeholder attn
        # backend, as the attention metadata is needed to manage internal state.
        # However we must bypass attention selection altogether for some models
        # used for speculative decoding to avoid a divide-by-zero in
        # model_config.get_head_size()
        num_attn_heads = self.model_config.get_num_attention_heads(
            self.parallel_config)
        needs_attn_backend = (num_attn_heads != 0
                              or self.model_config.is_attention_free)

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry

        # Lazy initialization
        self.model: nn.Module  # Set after load_model
        # Set after load_model (stay None when the feature is disabled).
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: Optional[
            LRUCacheWorkerPromptAdapterManager] = None
        self.sampler = get_sampler()

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))

        # Used to cache python objects
        self.inter_data_cache: Dict[int, PyObjectCache] = {}

        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceGroupToSample object. In Pipeline-Parallel, we have
        # more than 1 Scheduler, resulting in a potential back-to-back
        # prepare_model_inputs() call. This clobbers the cached
        # SequenceGroupToSample objects, as we reset the cache during
        # every prepare_model_inputs() call.
        self.sampling_metadata_cache: Optional[SamplingMetadataCache] = (
            SamplingMetadataCache()
            if self.parallel_config.pipeline_parallel_size == 1 else None)

        if hasattr(self, "_builder_cls"):
            # multi-step model runner does not have `_builder_cls`
            self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        """Load the model and attach optional LoRA / prompt-adapter wrappers.

        Loading (including the LoRA wrap) runs under ``DeviceMemoryProfiler``.
        The measured usage is stored in ``self.model_memory_usage`` and logged
        together with the elapsed wall time. When the compilation level is
        ``CompilationLevel.DYNAMO_AS_IS`` and ``supports_dynamo()`` holds, the
        final model is wrapped with ``torch.compile``.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler(self.device) as mem_profiler:
            load_started = time.perf_counter()
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                assert supports_lora(
                    self.model
                ), f"{self.model.__class__.__name__} does not support LoRA yet."

                if supports_multimodal(self.model):
                    logger.warning(
                        "Regarding multimodal models, vLLM currently "
                        "only supports adding LoRA to language model.")

                # get_text_config() handles multimodal model configs too.
                text_config = self.model_config.hf_config.get_text_config()
                sched_cfg = self.scheduler_config
                self.lora_manager = LRUCacheWorkerLoRAManager(
                    sched_cfg.max_num_seqs,
                    sched_cfg.max_num_batched_tokens,
                    self.vocab_size,
                    self.lora_config,
                    self.device,
                    self.model.embedding_modules,
                    self.model.embedding_padding_modules,
                    max_position_embeddings=(
                        text_config.max_position_embeddings),
                )
                self.model = self.lora_manager.create_lora_manager(self.model)
            load_finished = time.perf_counter()

        self.model_memory_usage = mem_profiler.consumed_memory
        logger.info("Model loading took %.4f GiB and %.6f seconds",
                    self.model_memory_usage / GiB_bytes,
                    load_finished - load_started)

        if self.prompt_adapter_config:
            sched_cfg = self.scheduler_config
            self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
                sched_cfg.max_num_seqs, sched_cfg.max_num_batched_tokens,
                self.device, self.prompt_adapter_config)
            self.model = (
                self.prompt_adapter_manager.create_prompt_adapter_manager(
                    self.model))

        compilation_config = self.vllm_config.compilation_config
        if (compilation_config.level == CompilationLevel.DYNAMO_AS_IS
                and supports_dynamo()):
            compile_backend = compilation_config.init_backend(
                self.vllm_config)
            self.model = torch.compile(
                self.model,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=compile_backend)

    def get_model(self) -> nn.Module:
        """Return ``self.model`` as last assigned (e.g. by ``load_model``)."""
        loaded_model = self.model
        return loaded_model

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write ``self.model`` to ``path`` using ``ShardedStateLoader``.

        ``pattern`` and ``max_size`` are forwarded unchanged to
        ``ShardedStateLoader.save_model``.
        """
        from vllm.model_executor.model_loader import ShardedStateLoader

        save_model = ShardedStateLoader.save_model
        save_model(self.model, path, pattern=pattern, max_size=max_size)

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``self.model`` with ``TensorizerLoader.save_model``,
        using ``tensorizer_config`` for the output settings."""
        from vllm.model_executor.model_loader import TensorizerLoader

        save_model = TensorizerLoader.save_model
        save_model(self.model, tensorizer_config=tensorizer_config)

    def get_max_block_per_batch(self) -> int:
        """Return ceil(max_seq_len_to_capture / block_size).

        This is the number of ``block_size``-token blocks needed to cover
        ``max_seq_len_to_capture`` tokens.
        """
        token_cap, tokens_per_block = (self.max_seq_len_to_capture,
                                       self.block_size)
        return (token_cap + tokens_per_block - 1) // tokens_per_block

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> TModelInputForGPU:
        """Build base-model forward inputs for a batch of sequence groups.

        Only metadata for the base model forward pass is prepared. Metadata
        for possible additional steps (e.g. sampling) is not.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Raises:
            InputProcessingError: if adding a sequence group to the builder
                raises. It carries the offending request id and the original
                message, and is chained to the original exception.
        """
        builder = self.builder
        builder.prepare(finished_requests_ids)
        for group_metadata in seq_group_metadata_list:
            try:
                builder.add_seq_group(group_metadata)
            except Exception as err:
                # Re-raise tagged with the request ID of the bad request.
                raise InputProcessingError(group_metadata.request_id,
                                           str(err)) from err

        builder.reset_cached_inter_data()

        return builder.build()  # type: ignore

    @contextmanager
    def set_in_profile_run(self):
        """Context manager that sets ``self.in_profile_run`` for its duration.

        The flag is set to True on entry and is always reset to False on
        exit, including when the ``with`` body raises.
        """
        self.in_profile_run = True
        try:
            yield
        finally:
            self.in_profile_run = False

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run ``_dummy_run`` sized by the scheduler's batching limits.

        The dummy run uses ``max_num_batched_tokens`` tokens spread over
        ``max_num_seqs`` sequences, under ``torch.inference_mode``.
        """
        sched_cfg = self.scheduler_config
        token_budget = sched_cfg.max_num_batched_tokens
        seq_budget = sched_cfg.max_num_seqs
        self._dummy_run(token_budget, seq_budget)

    def _add_dummy_loras(self, num_loras: int) -> list[LoRARequest]:
        """Register ``num_loras`` placeholder LoRAs for warmup/profiling.

        LoRA ids run from 1 to ``num_loras``. Each request is named
        ``warmup_<id>``, points at a non-existent path, and is added with
        rank ``LORA_WARMUP_RANK`` inside ``lora_manager.dummy_lora_cache()``.
        Returns the created requests in id order.
        """
        assert num_loras > 0
        assert self.lora_manager is not None

        warmup_requests: list[LoRARequest] = []
        with self.lora_manager.dummy_lora_cache():
            for lora_id in range(1, num_loras + 1):
                request = LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                self.lora_manager.add_dummy_lora(request,
                                                 rank=LORA_WARMUP_RANK)
                warmup_requests.append(request)
        return warmup_requests

    def _remove_dummy_loras(self):
        # Remove dummy loras.
        assert self.lora_manager is not None
        self.remove_all_loras()

1327
1328
1329
    def _dummy_run(self,
                   max_num_batched_tokens: int,
                   max_num_seqs: int = 1) -> None:
        """Run one forward pass on synthetic prompts to profile memory usage.

        Builds up to ``max_num_seqs`` dummy prompt sequence groups whose
        lengths sum to ``max_num_batched_tokens``. When the model accepts
        multi-modal input, fewer sequences may be used (see below). When LoRA
        is enabled, dummy LoRA adapters are attached to the sequences. The
        model is then executed once with empty placeholder KV caches while
        ``in_profile_run`` is set.

        Dummy LoRA adapters are always unregistered afterwards, even if the
        forward pass raises.

        Args:
            max_num_batched_tokens: Total number of prompt tokens to spread
                across the dummy sequences.
            max_num_seqs: Upper bound on the number of dummy sequence groups.
        """
        with self.set_in_profile_run():
            # Enable top-k sampling to reflect the accurate memory usage.
            sampling_params = \
                SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)

            # This represents the maximum number of different requests
            # that will have unique loras, and therefore the max amount of
            # memory consumption. Register ``max_loras`` dummy adapters and
            # assign them round-robin to the dummy sequences.
            dummy_lora_requests: List[LoRARequest] = []
            dummy_lora_requests_per_seq: List[LoRARequest] = []
            if self.lora_config:
                dummy_lora_requests = self._add_dummy_loras(
                    self.lora_config.max_loras)
                assert len(dummy_lora_requests) == self.lora_config.max_loras
                dummy_lora_requests_per_seq = [
                    dummy_lora_requests[idx % len(dummy_lora_requests)]
                    for idx in range(max_num_seqs)
                ]

            try:
                # Profile memory usage with max_num_sequences sequences and
                # the total number of tokens equal to max_num_batched_tokens.
                seqs: List[SequenceGroupMetadata] = []
                # Additional GPU memory may be needed for multi-modal
                # encoding, which needs to be accounted for when calculating
                # the GPU blocks for the vLLM block manager.
                # To exercise the worst scenario for GPU memory consumption,
                # the number of seqs (batch_size) is chosen to maximize the
                # number of images processed.
                max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
                    self.model_config)
                if max_mm_tokens > 0:
                    max_num_seqs_orig = max_num_seqs
                    max_num_seqs = min(
                        max_num_seqs, max_num_batched_tokens // max_mm_tokens)
                    if max_num_seqs < 1:
                        expr = (f"min({max_num_seqs_orig}, "
                                f"{max_num_batched_tokens} // "
                                f"{max_mm_tokens})")
                        logger.warning(
                            "Computed max_num_seqs (%s) to be less than 1. "
                            "Setting it to the minimum value of 1.", expr)
                        max_num_seqs = 1

                # Spread the tokens as evenly as possible: the first
                # ``num_longer`` groups get one extra token.
                base_len, num_longer = divmod(max_num_batched_tokens,
                                              max_num_seqs)
                # ``batch_size`` counts tokens (the sum of all seq_len); it
                # sizes the intermediate tensors on non-first PP ranks.
                batch_size = 0
                for group_id in range(max_num_seqs):
                    seq_len = base_len + (group_id < num_longer)
                    batch_size += seq_len

                    dummy_data = self.input_registry \
                        .dummy_data_for_profiling(self.model_config,
                                                  seq_len,
                                                  self.mm_registry)

                    seq = SequenceGroupMetadata(
                        request_id=str(group_id),
                        is_prompt=True,
                        seq_data={group_id: dummy_data.seq_data},
                        sampling_params=sampling_params,
                        block_tables=None,
                        lora_request=dummy_lora_requests_per_seq[group_id]
                        if dummy_lora_requests_per_seq else None,
                        multi_modal_data=dummy_data.multi_modal_data,
                        multi_modal_placeholders=dummy_data.
                        multi_modal_placeholders,
                    )
                    seqs.append(seq)

                # Run the model with the dummy inputs.
                num_layers = self.model_config.get_num_layers(
                    self.parallel_config)
                # Use an empty tensor instead of `None` to force Dynamo to
                # pass it by reference, rather than specializing on the value
                # `None`. The `dtype` argument does not matter; `float32` is
                # a placeholder (it has wide hardware support). The tensors
                # are created inside the loop, rather than by multiplying a
                # list, so that Dynamo does not treat them as aliases.
                kv_caches = [
                    torch.tensor([], dtype=torch.float32, device=self.device)
                    for _ in range(num_layers)
                ]
                finished_requests_ids = [seq.request_id for seq in seqs]
                model_input = self.prepare_model_input(
                    seqs, finished_requests_ids=finished_requests_ids)
                intermediate_tensors = None
                if not get_pp_group().is_first_rank:
                    intermediate_tensors = (
                        self.model.make_empty_intermediate_tensors(
                            batch_size=batch_size,
                            dtype=self.model_config.dtype,
                            device=self.device))

                # Disable KV scale calculation for dummy data during the
                # profile run.
                if model_input.attn_metadata is not None:
                    model_input.attn_metadata \
                        .enable_kv_scales_calculation = False

                self.execute_model(model_input, kv_caches,
                                   intermediate_tensors)
                torch.cuda.synchronize()
            finally:
                # Unregister the dummy adapters even if profiling failed, so
                # they do not stay registered in the LoRA manager.
                if self.lora_config:
                    self._remove_dummy_loras()
1434

1435
    def remove_all_loras(self):
        """Remove every adapter from the LoRA manager.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.remove_all_adapters()
1439

1440
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Activate ``lora_requests`` using ``lora_mapping``.

        Forwards both arguments to the LoRA manager's
        ``set_active_adapters``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
1445
1446
1447
1448

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Add a LoRA adapter via the LoRA manager.

        Returns:
            The result of the manager's ``add_adapter`` call.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.add_adapter(lora_request)
1450
1451
1452
1453

    def remove_lora(self, lora_id: int) -> bool:
        """Remove the LoRA adapter with id ``lora_id``.

        Returns:
            The result of the manager's ``remove_adapter`` call.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_adapter(lora_id)
1455
1456
1457
1458

    def pin_lora(self, lora_id: int) -> bool:
        """Pin the LoRA adapter with id ``lora_id`` in the LoRA manager.

        Returns:
            The result of the manager's ``pin_adapter`` call.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_adapter(lora_id)
1460
1461
1462
1463

    def list_loras(self) -> Set[int]:
        """Return the adapters reported by the LoRA manager's
        ``list_adapters()``.

        Raises:
            RuntimeError: If LoRA is not enabled (no LoRA manager).
        """
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.list_adapters()

    def remove_all_prompt_adapters(self):
        """Remove every prompt adapter from the prompt adapter manager.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        manager.remove_all_adapters()

    def set_active_prompt_adapters(
            self, prompt_adapter_requests: Set[PromptAdapterRequest],
            prompt_adapter_mapping: PromptAdapterMapping) -> None:
        """Activate the given prompt adapters with ``prompt_adapter_mapping``.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        manager.set_active_adapters(prompt_adapter_requests,
                                    prompt_adapter_mapping)

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Add a prompt adapter via the prompt adapter manager.

        Returns:
            The result of the manager's ``add_adapter`` call.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.add_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Remove the prompt adapter with id ``prompt_adapter_id``.

        Returns:
            The result of the manager's ``remove_adapter`` call.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.remove_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Pin the prompt adapter with id ``prompt_adapter_id``.

        Returns:
            The result of the manager's ``pin_adapter`` call.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.pin_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Return what the prompt adapter manager's ``list_adapters()``
        reports.

        Raises:
            RuntimeError: If prompt adapters are not enabled.
        """
        manager = self.prompt_adapter_manager
        if not manager:
            raise RuntimeError("PromptAdapter is not enabled.")
        return manager.list_adapters()
1499

1500
    @torch.inference_mode()
    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
        """Cuda graph capture a model.

        Note that CUDA graph's performance gain is negligible if number
        of batched tokens are larger than 200. And since CUDA graph
        requires fixed sized tensors, supporting large/variable batch
        size requires high GPU memory overhead. Thus, vLLM only captures
        decoding requests. Mixed batch (chunked prefill + decoding) or
        prefill requests are not captured.

        Since it is used for decoding-only, it assumes there's only 1 token
        per sequence in the batch.

        One graph runner is captured per (virtual engine, batch size,
        whether ``inputs_embeds`` is used) and stored in
        ``self.graph_runners[virtual_engine][(batch_size, use_embeds)]``.
        Dummy LoRA adapters registered for the capture are always removed
        afterwards, even if capture fails.

        Args:
            kv_caches: KV caches, indexed by virtual engine.
        """
        assert not self.model_config.enforce_eager
        logger.info("Capturing cudagraphs for decoding. This may lead to "
                    "unexpected consequences if the model is not static. To "
                    "run the model in eager mode, set 'enforce_eager=True' or "
                    "use '--enforce-eager' in the CLI. "
                    "If out-of-memory error occurs during cudagraph capture,"
                    " consider decreasing `gpu_memory_utilization` or "
                    "switching to eager mode. You can also reduce the "
                    "`max_num_seqs` as needed to decrease memory usage.")
        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Prepare dummy inputs. These will be reused for all batch sizes.
        max_batch_size = self.max_batchsize_to_capture
        input_tokens = torch.zeros(max_batch_size,
                                   dtype=torch.long,
                                   device=self.device)
        input_positions = torch.zeros(max_batch_size,
                                      dtype=torch.long,
                                      device=self.device)
        inputs_embeds = torch.zeros(
            (max_batch_size, self.model_config.get_hidden_size()),
            dtype=self.model_config.dtype,
            device=self.device)
        if self.model_config.uses_mrope:
            # M-RoPE models get the positions replicated into 3 rows.
            input_positions = torch.tile(input_positions,
                                         (3, 1)).cuda(device=self.device)

        # Prepare dummy previous_hidden_states only if needed by the model.
        # This is used by draft models such as EAGLE.
        previous_hidden_states = None
        if "previous_hidden_states" in inspect.signature(
                self.model.forward).parameters:
            previous_hidden_states = torch.empty(
                [max_batch_size,
                 self.model_config.get_hidden_size()],
                dtype=self.model_config.dtype,
                device=self.device)

        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
                batch_size=max_batch_size,
                dtype=self.model_config.dtype,
                device=self.device)

        dummy_lora_id: Optional[int] = None
        dummy_lora_request: Optional[LoRARequest] = None
        if self.lora_config:
            # The goal is to capture the LoRA kernels in cuda graphs.
            # For this purpose, a single dummy lora is sufficient.
            dummy_lora_requests = self._add_dummy_loras(num_loras=1)
            assert len(dummy_lora_requests) == 1
            dummy_lora_request = dummy_lora_requests[0]
            dummy_lora_id = dummy_lora_request.lora_int_id

        try:
            with self.attn_state.graph_capture(max_batch_size), graph_capture(
                    self.device) as graph_capture_context:
                # NOTE: Capturing the largest batch size first may help
                # reduce the memory usage of CUDA graph.
                for virtual_engine in range(
                        self.parallel_config.pipeline_parallel_size):
                    # We need to not only iterate over batch sizes, but also
                    # whether to use inputs_embeds or not, hence we use the
                    # cartesian product.
                    cudagraph_capture_sizes = (
                        self.vllm_config.compilation_config.
                        cudagraph_capture_sizes)
                    cudagraph_inputs_embeds = (
                        (True, False)
                        if self.model_config.enable_prompt_embeds else
                        (False, ))
                    compilation_cases = itertools.product(
                        cudagraph_capture_sizes,
                        cudagraph_inputs_embeds,
                    )
                    # Only rank 0 should print progress bar during capture
                    if get_tensor_model_parallel_rank() == 0:
                        compilation_cases = tqdm(
                            list(compilation_cases),
                            desc="Capturing CUDA graph shapes")
                    for batch_size, use_inputs_embeds in compilation_cases:
                        attn_metadata = (
                            self.attn_state.
                            graph_capture_get_metadata_for_batch(
                                batch_size,
                                is_encoder_decoder_model=self.model_config.
                                is_encoder_decoder))
                        # Disable KV Scale Calculation for graph capture
                        attn_metadata.enable_kv_scales_calculation = False
                        if self.lora_config:
                            lora_mapping = LoRAMapping(
                                index_mapping=[dummy_lora_id] * batch_size,
                                prompt_mapping=[dummy_lora_id] * batch_size,
                                is_prefill=False)
                            self.set_active_loras({dummy_lora_request},
                                                  lora_mapping)

                        if self.prompt_adapter_config:
                            prompt_adapter_mapping = PromptAdapterMapping(
                                [-1] * batch_size,
                                [-1] * batch_size,
                            )
                            self.set_active_prompt_adapters(
                                set(), prompt_adapter_mapping)
                        graph_runner = CUDAGraphRunner(
                            self.model, self.attn_backend.get_name(),
                            self.attn_state.graph_clone(batch_size),
                            self.model_config.is_encoder_decoder)

                        capture_inputs = {
                            "input_ids":
                            input_tokens[:batch_size],
                            "inputs_embeds":
                            inputs_embeds[:batch_size]
                            if use_inputs_embeds else None,
                            "positions":
                            input_positions[..., :batch_size],
                            "intermediate_inputs":
                            intermediate_inputs[:batch_size]
                            if intermediate_inputs is not None else None,
                            "kv_caches":
                            kv_caches[virtual_engine],
                            "attn_metadata":
                            attn_metadata,
                            "memory_pool":
                            self.graph_memory_pool,
                            "stream":
                            graph_capture_context.stream
                        }
                        if previous_hidden_states is not None:
                            capture_inputs["previous_hidden_states"] = (
                                previous_hidden_states[:batch_size])

                        if self.has_inner_state:
                            # Only used by Mamba-based models CUDA graph atm
                            # (Jamba)
                            capture_inputs.update({
                                "seqlen_agnostic_capture_inputs":
                                self.model.get_seqlen_agnostic_capture_inputs(
                                    batch_size)
                            })
                        if self.model_config.is_encoder_decoder:
                            # add the additional inputs to capture for
                            # encoder-decoder models.
                            self._update_inputs_to_capture_for_enc_dec_model(
                                capture_inputs)

                        with set_forward_context(attn_metadata,
                                                 self.vllm_config,
                                                 virtual_engine):
                            graph_runner.capture(**capture_inputs)
                        # The next capture is handed this graph's memory pool
                        # (via capture_inputs["memory_pool"]).
                        self.graph_memory_pool = graph_runner.graph.pool()
                        self.graph_runners[virtual_engine][(
                            batch_size, use_inputs_embeds)] = graph_runner
        finally:
            # Unregister the dummy adapter even if capture failed.
            if self.lora_config:
                self._remove_dummy_loras()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes < 10 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / GiB_bytes)
1675

1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
    def _update_inputs_to_capture_for_enc_dec_model(self,
                                                    capture_inputs: Dict[str,
                                                                         Any]):
        """
        Updates the set of input tensors needed for CUDA graph capture in an
        encoder-decoder model.

        This method modifies the provided `capture_inputs` dictionary by
        adding tensors specific to encoder-decoder specific models that
        need to be captured for CUDA Graph replay.
        """
        # During the decode phase encoder_input_ids and encoder_positions are
        # unset. Do the same thing for graph capture.
1689
1690
1691
1692
1693
1694
        capture_inputs["encoder_input_ids"] = torch.tensor([],
                                                           dtype=torch.long,
                                                           device=self.device)
        capture_inputs["encoder_positions"] = torch.tensor([],
                                                           dtype=torch.long,
                                                           device=self.device)
1695

1696
1697
1698
1699
    @property
    def vocab_size(self) -> int:
        """Vocabulary size as reported by the model config."""
        model_config = self.model_config
        return model_config.get_vocab_size()

1700

1701
1702
1703
1704
1705
1706
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
    """
    GPU model runner with sampling step.
    """
    _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
        ModelInputForGPUWithSamplingMetadata)
1707
    _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder
1708
1709
1710
1711
1712

    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Rebuild a model input from a broadcast tensor dict.

        Delegates to
        ``ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict``
        and passes this runner's attention backend.
        """
        return (
            ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))
1719
1720
1721
1722

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        The API assumes seq_group_metadata_list is sorted by prefill -> decode.

        The result tensors and data structure also batches input in prefill
        -> decode order. For example,

        - input_tokens[:num_prefill_tokens] contains prefill tokens.
        - input_tokens[num_prefill_tokens:] contains decode tokens.

        If cuda graph is required, this API automatically pads inputs.

        Args:
            seq_group_metadata_list: Sequence groups to batch.
            virtual_engine: Virtual engine index recorded on the result.
            finished_requests_ids: IDs of finished requests. They are
                forwarded to ``_prepare_model_input_tensors`` and
                ``get_generators``.

        Returns:
            The prepared input, with three fields set:

            - ``sampling_metadata``: ``None`` except on the last
              pipeline-parallel rank.
            - ``is_prompt``: taken from the first sequence group, or
              ``None`` if the list is empty.
            - ``virtual_engine``.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        if get_pp_group().is_last_rank:
            # Sampling metadata is only required for the final pp group
            generators = self.get_generators(finished_requests_ids)
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list, model_input.seq_lens,
                model_input.query_lens, self.device, self.pin_memory,
                generators, self.sampling_metadata_cache)
        else:
            sampling_metadata = None
        is_prompt = (seq_group_metadata_list[0].is_prompt
                     if seq_group_metadata_list else None)
        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   is_prompt=is_prompt,
                                   virtual_engine=virtual_engine)
1756
1757
1758
1759
1760
1761

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
1762
        intermediate_tensors: Optional[IntermediateTensors] = None,
1763
        num_steps: int = 1,
1764
        **kwargs,
1765
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
1766
1767
1768
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

1769
1770
1771
1772
1773
1774
        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

1775
1776
1777
1778
1779
1780
1781
        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

1782
        self.attn_state.begin_forward(model_input)
1783

1784
1785
1786
1787
        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
1788
1789
1790
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
1791
        previous_hidden_states = kwargs.get("previous_hidden_states")
1792
1793
1794
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
1795
1796
1797
            use_inputs_embeds = model_input.inputs_embeds is not None
            model_executable = self.graph_runners[virtual_engine][(
                graph_batch_size, use_inputs_embeds)]
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
            if previous_hidden_states is not None:
                previous_hidden_states = torch.cat([
                    previous_hidden_states,
                    torch.empty([
                        graph_batch_size - previous_hidden_states.shape[0],
                        *previous_hidden_states.shape[1:]
                    ],
                                dtype=previous_hidden_states.dtype,
                                device=previous_hidden_states.device)
                ])
1808
1809
1810
        else:
            model_executable = self.model

1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
        # Receive KV cache in distributed KV cache transfer setting
        # In disagg prefill setting, it will also recv hidden states and bypass
        # model forwarding
        # In KV cache database setting, it will change the model input so that
        # we can skip prefilling on tokens that successfully received KV caches
        # NOTE: The receive operation is blocking
        bypass_model_exec = False
        if self.need_recv_kv(model_input, kv_caches):
            hidden_or_intermediate_states, bypass_model_exec, model_input = \
                get_kv_transfer_group().recv_kv_caches_and_hidden_states(
                    # model is used to know which layer the current worker
                    # is working on, so that we can receive KV for only those
                    # layers.
                    model_executable,
                    model_input,
                    kv_caches=kv_caches
                )

1829
        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
Mor Zusman's avatar
Mor Zusman committed
1830
1831
1832
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
1833
        } if self.has_inner_state else {}
1834
1835
1836
        model_kwargs = {}
        if previous_hidden_states is not None:
            model_kwargs["previous_hidden_states"] = previous_hidden_states
1837
1838
1839
1840
1841
1842
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start = torch.cuda.Event(enable_timing=True)
            model_forward_end = torch.cuda.Event(enable_timing=True)
            model_forward_start.record()

1843
1844
        if not bypass_model_exec:
            with set_forward_context(model_input.attn_metadata,
1845
                                     self.vllm_config, virtual_engine):
1846
1847
                hidden_or_intermediate_states = model_executable(
                    input_ids=model_input.input_tokens,
1848
                    inputs_embeds=model_input.inputs_embeds,
1849
1850
1851
1852
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                                 device=self.device),
1853
1854
1855
                    **seqlen_agnostic_kwargs,
                    **model_kwargs,
                )
1856

1857
1858
1859
1860
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
        # Sending KV cache in distributed KV cache transfer setting
        # NOTE: the send operation is non-blocking
        if self.need_send_kv(model_input, kv_caches):
            get_kv_transfer_group().send_kv_caches_and_hidden_states(
                # model_executable is used to know which layer the current
                # worker is working on, so that we can send KV for only those
                # layers.
                model_executable,
                model_input,
                kv_caches,
                hidden_or_intermediate_states,
            )

1874
1875
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
            if (self.is_driver_worker
                    and hidden_or_intermediate_states is not None
                    and isinstance(hidden_or_intermediate_states,
                                   IntermediateTensors)
                    and self.observability_config is not None
                    and self.observability_config.collect_model_forward_time):
                model_forward_end.synchronize()
                model_forward_time = model_forward_start.elapsed_time(
                    model_forward_end)
                orig_model_forward_time = 0.0
                if intermediate_tensors is not None:
                    orig_model_forward_time = intermediate_tensors.tensors.get(
                        "model_forward_time", torch.tensor(0.0)).item()
                hidden_or_intermediate_states.tensors["model_forward_time"] = (
                    torch.tensor(model_forward_time + orig_model_forward_time))
1891
1892
1893
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
1894
1895
1896
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
1897
            return []
1898

1899
1900
        if model_input.async_callback is not None:
            model_input.async_callback()
1901

1902
        # Sample the next token.
1903
1904
1905
1906
1907
        assert isinstance(self.sampler, Sampler)
        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
        if model_input.inputs_embeds is not None:
            self.sampler.include_gpu_probs_tensor = True

1908
        output: SamplerOutput = self.sampler(
1909
1910
1911
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
1912
1913
1914
1915
1916
1917
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
1918
1919
1920
1921
            orig_model_forward_time = 0.0
            if intermediate_tensors is not None:
                orig_model_forward_time = intermediate_tensors.tensors.get(
                    "model_forward_time", torch.tensor(0.0)).item()
1922
1923
1924
1925
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
1926
1927
            output.model_forward_time = (orig_model_forward_time +
                                         model_forward_time)
1928

1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
        if model_input.inputs_embeds is not None:
            self.sampler.include_gpu_probs_tensor = \
                orig_include_gpu_probs_tensor
            if output.sampled_token_ids is not None:
                output.sampled_token_embeds = self.model.get_input_embeddings(
                    output.sampled_token_ids.squeeze(1))

                for token_embed, sequence_group_output in zip(
                        output.sampled_token_embeds, output.outputs):
                    assert len(sequence_group_output.samples) == 1
                    sequence_group_output.samples[0].output_embed = token_embed

1941
1942
        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
1943
1944
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
1945
            if model_input.is_prompt:
1946
1947
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
1948
                output.prefill_hidden_states = hidden_or_intermediate_states
1949
            elif decode_meta.use_cuda_graph:
1950
1951
1952
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states
1953

1954
1955
            output.hidden_states = hidden_states

1956
        return [output]
1957

1958
1959
1960
1961
1962
1963
    def need_recv_kv(self, model_input, kv_caches) -> bool:
        """Check if we need to receive kv-cache from the other worker.

        We need to receive KV when
            1. current vLLM instance is KV cache consumer/decode vLLM instance
            2. this batch is not a profiling run
            3. this batch is a prefill run

        Args:
            model_input: input to the model executable
            kv_caches: vLLM's paged memory

        Returns:
            True only when all three conditions above hold; False whenever
            no KV transfer config is set.
        """
        kv_transfer_config = self.vllm_config.kv_transfer_config
        if kv_transfer_config is None:
            return False

        prefill_meta = model_input.attn_metadata.prefill_metadata

        # check if the current run is profiling: an empty first KV cache
        # tensor marks a profiling run. An empty list (no KV caches at all)
        # has nothing to receive into either, so it is treated the same way
        # instead of raising IndexError on `kv_caches[0]`.
        is_profile_run = not kv_caches or kv_caches[0].numel() == 0
        # check if the current run is prefill
        is_prefill_run = prefill_meta is not None

        return kv_transfer_config.is_kv_consumer and (
            not is_profile_run) and is_prefill_run

    def need_send_kv(self, model_input, kv_caches) -> bool:
        """Check if we need to send kv-cache to the other worker.

        We need to send KV when
            1. current vLLM instance is KV cache producer/prefill vLLM instance
            2. this batch is not a profiling run
            3. this batch is a prefill run

        Args:
            model_input: input to the model executable
            kv_caches: vLLM's paged memory

        Returns:
            True only when all three conditions above hold; False whenever
            no KV transfer config is set.
        """
        kv_transfer_config = self.vllm_config.kv_transfer_config
        if kv_transfer_config is None:
            return False

        prefill_meta = model_input.attn_metadata.prefill_metadata

        # check if the current run is profiling: an empty first KV cache
        # tensor marks a profiling run. An empty list (no KV caches at all)
        # has nothing to send either, so it is treated the same way instead
        # of raising IndexError on `kv_caches[0]`.
        is_profile_run = not kv_caches or kv_caches[0].numel() == 0
        # check if the current run is prefill
        is_prefill_run = prefill_meta is not None

        return kv_transfer_config.is_kv_producer and (
            not is_profile_run) and is_prefill_run

2008

2009
2010
2011
# NOTE: this is nn.Module so the profiler can properly capture/group
#  kernels calls made within the graph
class CUDAGraphRunner(nn.Module):
    """Records one model forward pass into a CUDA graph and replays it.

    `capture` runs the model a few times for warm-up, then records a single
    forward pass into a `torch.cuda.CUDAGraph`. It remembers the tensors used
    as inputs (`input_buffers`) and the produced outputs (`output_buffers`).
    `forward` copies fresh inputs into those same input tensors, replays the
    graph and returns the recorded output buffers.
    """

    # Intermediate-tensor keys that carry timing metadata rather than model
    # activations; `forward` never copies them into the graph's input buffers.
    _NON_GRAPH_INTERMEDIATE_KEYS = frozenset(
        ("model_execute_time", "model_forward_time"))

    def __init__(self, model: nn.Module, backend_name: str,
                 attn_state: AttentionState, is_encoder_decoder_model: bool):
        super().__init__()
        self.model = model
        self.backend_name = backend_name
        self.attn_state = attn_state

        # Filled by `capture`. Mostly tensors, but "kv_caches" holds the list
        # of KV cache tensors passed to `capture`.
        self.input_buffers: Dict[str, Any] = {}
        # Filled by `capture`: {"hidden_states": ...} on the last
        # pipeline-parallel rank, an IntermediateTensors instance otherwise.
        self.output_buffers: Union[Dict[str, torch.Tensor],
                                   IntermediateTensors] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None
        self._is_encoder_decoder_model = is_encoder_decoder_model

    @property
    def graph(self) -> torch.cuda.CUDAGraph:
        """The captured CUDA graph; asserts that `capture` has already run."""
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        memory_pool: Optional[Tuple[int, int]],
        stream: torch.cuda.Stream,
        **kwargs,
    ):
        """Warm up the model, then record one forward pass as a CUDA graph.

        Args:
            input_ids: token-id tensor used for capture; kept as the
                "input_ids" input buffer.
            inputs_embeds: optional embedding tensor; passed to the model and
                kept as an input buffer only when not None.
            positions: positions tensor; kept as the "positions" buffer.
            intermediate_inputs: tensors from the previous pipeline stage, if
                any; their entries are added to the input buffers.
            kv_caches: KV cache tensors; kept as the "kv_caches" buffer.
            attn_metadata: passed to the attention state to build its graph
                input buffers.
            memory_pool: memory pool handle given to `torch.cuda.graph`.
            stream: stream given to `torch.cuda.graph`.
            **kwargs: extra model inputs; also kept as input buffers.

        Raises:
            TypeError: if the model output is neither a `torch.Tensor` nor an
                `IntermediateTensors` instance.
        """
        assert self._graph is None

        # Warm-up and capture must call the model identically. In particular,
        # `inputs_embeds` is only passed when it is provided.
        model_kwargs: Dict[str, Any] = {"input_ids": input_ids}
        if inputs_embeds is not None:
            model_kwargs["inputs_embeds"] = inputs_embeds
        model_kwargs["positions"] = positions
        model_kwargs["intermediate_tensors"] = intermediate_inputs

        # Run the model a few times without capturing the graph.
        # This is to make sure that the captured graph does not include the
        # kernel launches for initial benchmarking (e.g., Triton autotune).
        # Note one iteration is not enough for torch.compile
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(**model_kwargs, **kwargs)

        # Wait for the warm up operations to finish before proceeding with
        # Graph Capture.
        torch.cuda.synchronize()
        # Capture the graph.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            output_hidden_or_intermediate_states = self.model(
                **model_kwargs, **kwargs)
            output_type = type(output_hidden_or_intermediate_states)

            if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
                hidden_or_intermediate_states = weak_ref_tensor(
                    output_hidden_or_intermediate_states)
            elif isinstance(output_hidden_or_intermediate_states,
                            IntermediateTensors):
                hidden_or_intermediate_states = IntermediateTensors(
                    tensors={
                        key: weak_ref_tensor(value)
                        for key, value in
                        output_hidden_or_intermediate_states.tensors.items()
                    })
            else:
                # Reported after capture ends (below) so the graph context is
                # exited normally before raising.
                hidden_or_intermediate_states = None

            del output_hidden_or_intermediate_states
            # make sure `output_hidden_or_intermediate_states` is deleted
            # in the graph's memory pool
            gc.collect()
        torch.cuda.synchronize()

        if hidden_or_intermediate_states is None:
            raise TypeError(
                "CUDA graph capture expected the model to return a "
                "torch.Tensor or IntermediateTensors, got "
                f"{output_type.__name__}")

        # Save the input and output buffers.
        self.input_buffers = {
            "input_ids":
            input_ids,
            **({
                "inputs_embeds": inputs_embeds,
            } if inputs_embeds is not None else {}),
            "positions":
            positions,
            "kv_caches":
            kv_caches,
            **self.attn_state.get_graph_input_buffers(
                attn_metadata, self._is_encoder_decoder_model),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:
            self.output_buffers = {
                "hidden_states": hidden_or_intermediate_states
            }
        else:
            self.output_buffers = hidden_or_intermediate_states

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Copy new inputs into the captured buffers and replay the graph.

        Returns:
            On the last pipeline-parallel rank, the buffer stored under
            "hidden_states"; on other ranks, the captured IntermediateTensors.
            Both are the graph's static output buffers, overwritten by the
            next replay.
        """
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        if positions is not None:
            # in some case like MLA, it will reuse positions in metadata
            # but truncate them to the original size
            # so the shape is not padded, we need to copy partial only
            self.input_buffers["positions"][:positions.shape[0]].copy_(
                positions, non_blocking=True)
        if inputs_embeds is not None:
            self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
                inputs_embeds, non_blocking=True)

        if self.backend_name != "NO_ATTENTION":
            self.input_buffers["slot_mapping"].copy_(
                attn_metadata.slot_mapping, non_blocking=True)

        self.attn_state.prepare_graph_input_buffers(
            self.input_buffers, attn_metadata, self._is_encoder_decoder_model)

        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)

        if "previous_hidden_states" in self.input_buffers:
            self.input_buffers["previous_hidden_states"].copy_(
                kwargs["previous_hidden_states"], non_blocking=True)

        if intermediate_tensors is not None:
            for key in intermediate_tensors.tensors:
                # Timing metadata is not a graph input; skip it.
                if key not in self._NON_GRAPH_INTERMEDIATE_KEYS:
                    self.input_buffers[key].copy_(intermediate_tensors[key],
                                                  non_blocking=True)
        if self._is_encoder_decoder_model:
            self.input_buffers["encoder_input_ids"].copy_(
                kwargs['encoder_input_ids'], non_blocking=True)
            self.input_buffers["encoder_positions"].copy_(
                kwargs['encoder_positions'], non_blocking=True)

        # Run the graph.
        self.graph.replay()
        # Return the output tensor.
        if get_pp_group().is_last_rank:
            return self.output_buffers["hidden_states"]

        return self.output_buffers