utils.py 11.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections import defaultdict
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Optional
6

7
8
import torch

9
from vllm.attention.backends.abstract import AttentionBackend
10
from vllm.config import ModelConfig, SchedulerConfig
11
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
12
from vllm.model_executor.models.utils import extract_layer_index
13
from vllm.multimodal.cache import processor_only_cache_from_config
14
from vllm.multimodal.registry import MultiModalRegistry
15
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
16
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
17
18
from vllm.v1.kv_cache_interface import KVCacheGroupSpec

19
20
21
if TYPE_CHECKING:
    from vllm.attention.layer import Attention

22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models.

    On construction it queries the multi-modal registry for the per-prompt
    item limits and the maximum number of tokens per item of each modality,
    derives the encoder compute budget and encoder cache size from the
    scheduler config, and precomputes how many items of each modality fit in
    a single prompt and in a single batch.
    """

    def __init__(
        self,
        model_config: "ModelConfig",
        scheduler_config: "SchedulerConfig",
        mm_registry: "MultiModalRegistry",
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.mm_registry = mm_registry

        # Cache built by `processor_only_cache_from_config`; it is forwarded
        # to both registry queries below.
        self.cache = cache = processor_only_cache_from_config(
            model_config, mm_registry)

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
                                                              cache=cache)

        max_tokens_by_modality = mm_registry \
            .get_max_tokens_per_item_by_nonzero_modality(model_config,
                                                         cache=cache)

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            max_tokens_by_modality,
        )

        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        # NOTE: `get_max_items` reads the attributes assigned above
        # (limits, lengths, encoder budget), so this loop must come after them.
        max_items_per_prompt_by_modality = dict[str, int]()
        max_items_per_batch_by_modality = dict[str, int]()

        for modality, max_tokens in max_tokens_by_modality.items():
            (
                max_items_per_prompt,
                max_items_per_batch,
            ) = self.get_max_items(modality, max_tokens)

            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
            max_items_per_batch_by_modality[modality] = max_items_per_batch

        self.max_tokens_by_modality = max_tokens_by_modality
        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item needs the most tokens.

        Ties are resolved in favour of the modality that comes first in
        `self.max_tokens_by_modality`.

        Raises:
            ValueError: If `self.max_tokens_by_modality` is empty.
        """
        max_tokens_by_modality = self.max_tokens_by_modality
        if not max_tokens_by_modality:
            # `max()` would also raise ValueError here, but with a message
            # that says nothing about the actual cause.
            raise ValueError(
                "Cannot determine the modality with the most tokens: "
                "max_tokens_by_modality is empty.")

        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])

        return modality

    def get_encoder_budget(self) -> int:
        """Return the smaller of the encoder compute budget and cache size."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Compute how many items of `modality` fit per prompt and per batch.

        Args:
            modality: The modality name, used to look up `self.mm_limits`.
            max_tokens_per_item: The maximum number of tokens a single item
                of this modality can occupy.

        Returns:
            A tuple `(max_items_per_prompt, max_items_per_batch)`. It is
            `(0, 0)` when `max_tokens_per_item` is 0 or the encoder budget is
            0; otherwise both values are at least 1.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = self.get_encoder_budget()

        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        # Without chunked prefill, a request's items must fit within a single
        # batch's token budget, which also caps the number of requests.
        if not scheduler_config.enable_chunked_prefill:
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch


129
130
131
132
133
134
135
@dataclass
class AttentionGroup:
    """A set of attention layers that share one backend and metadata builder.

    Layers belonging to the same KV cache group may be spread over several
    `AttentionGroup`s when they use different attention backends.
    """

    # Attention backend class used by every layer in this group.
    backend: type[AttentionBackend]
    # Builder that produces attention metadata for the layers in this group.
    metadata_builder: AttentionMetadataBuilder
    # Names of the attention layers that belong to this group.
    layer_names: list[str]


136
def sanity_check_mm_encoder_outputs(
    mm_embeddings: "MultiModalEmbeddings",
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].

    Args:
        mm_embeddings: The embeddings returned by the model. They are expected
            to be either a list/tuple of 2D tensors or a single 3D tensor
            (whose per-item slices are then 2D).
        expected_num_items: The number of multimodal input items that were
            passed to the model.

    Raises:
        AssertionError: If `mm_embeddings` has an unexpected type, does not
            contain exactly `expected_num_items` items, or contains an item
            that is not 2D.
    """
    # NOTE: these are `assert` statements, so they are skipped entirely when
    # Python runs with `-O`.
    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert all(e.ndim == 2 for e in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")


def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    The mask follows the convention of
    [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][]. Placeholder
    rows not selected by the mask are filled with NaN (presumably so that any
    accidental use of them is easy to detect).

    Args:
        embeds: The multimodal embeddings.
          Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
          tokens need to be filled with multimodal embeddings. Its number of
          `True` entries must equal `num_embeds`. If `None`, `embeds` is
          returned unchanged.
          Shape: `(num_placeholders,)`

    Returns:
        The placeholder tensor.
          Shape: `(num_placeholders, embed_dim)`
    """
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    # Boolean row mask: the i-th True position receives embeds[i].
    placeholders[is_embed] = embeds
    return placeholders


def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse operation of [scatter_mm_placeholders][]: it keeps only
    the entries of `placeholders` selected by the boolean mask `is_embed`.

    Args:
        placeholders: The placeholder tensor to gather from.
        is_embed: Boolean mask selecting the embedding positions. If `None`,
          `placeholders` is returned unchanged.

    Returns:
        The gathered embeddings.
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]


def initialize_kv_cache_for_kv_sharing(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: "list[KVCacheGroupSpec]",
    kv_caches: dict[str, torch.Tensor],
    # Optional for now to avoid breaking TPU
    attn_groups: "Optional[list[list[AttentionGroup]]]" = None,
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    """
    Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
    for layers that do not allocate their own KV cache, based on the mapping
    in `shared_kv_cache_layers`. Adds these layers to the corresponding KV
    cache group (and attention group, when `attn_groups` is given), which is
    needed to ensure that attention metadata is assigned later.

    All container arguments are updated in place; nothing is returned.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model. Each shared layer
            is appended to the `layer_names` of its target's group.
        kv_caches: The allocated kv_caches with layer names as keys.
            Note that layers in shared_kv_cache_layers.keys() are not
            originally included as it only contains layers which have their
            own KV cache allocation. An entry is added for each shared layer.
        attn_groups: Optional list of attention groups. Layers in the same KV
            cache group may be placed in different attention groups if they
            have different attention backends. Currently only provided by
            GPU model runner.
        runner_only_attn_layers: Optional set; when given, every shared layer
            name is added to it.

    Raises:
        KeyError: If a target layer is missing from `kv_caches` or does not
            appear in `attn_groups` / `kv_cache_groups`.
    """
    # Mapping from layer name to tuple of (kv_cache_group_idx, attn_group_idx).
    layer_to_attn_group_idx: dict[str, tuple[int, int]] = {}
    if attn_groups:
        for kv_cache_group_idx, kv_attn_groups in enumerate(attn_groups):
            for attn_group_idx, attn_group in enumerate(kv_attn_groups):
                for layer_name in attn_group.layer_names:
                    layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx,
                                                           attn_group_idx)
    else:
        for kv_cache_group_idx, kv_cache_group in enumerate(kv_cache_groups):
            for layer_name in kv_cache_group.layer_names:
                # attn group idx default to 0 if not provided
                layer_to_attn_group_idx[layer_name] = (kv_cache_group_idx, 0)

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        kv_caches[layer_name] = kv_caches[target_layer_name]

        kv_cache_group_idx, attn_group_idx = \
            layer_to_attn_group_idx[target_layer_name]
        kv_cache_groups[kv_cache_group_idx].layer_names.append(layer_name)

        if attn_groups:
            attn_groups[kv_cache_group_idx][attn_group_idx].layer_names.append(
                layer_name)

        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by the layer index extracted from each name.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache list declared by ModelRunner. Must be
            empty on entry; it is filled in place.

    Raises:
        NotImplementedError: If two or more entries of `kv_caches` map to the
            same layer index.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.
            raise NotImplementedError(
                f"Multiple KV cache layers share layer index {layer_index}: "
                f"{layer_names}. Binding KV caches for such models is not "
                "supported yet.")
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]