utils.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections import defaultdict
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Optional
6

7
8
import torch

9
from vllm.attention.backends.abstract import AttentionBackend
10
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
11
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
12
from vllm.model_executor.models.utils import extract_layer_index
13
from vllm.multimodal.cache import processor_only_cache_from_config
14
from vllm.multimodal.registry import MultiModalRegistry
15
from vllm.platforms import current_platform
16
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
17
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
18
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
19

20
21
22
if TYPE_CHECKING:
    from vllm.attention.layer import Attention

23

24
25
26
27
28
29
30
31
32
33
34
35
36
37
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models.

    On construction it derives, for every modality returned by the registry's
    `get_max_tokens_per_item_by_nonzero_modality`:

    - `max_tokens_by_modality`: the maximum number of tokens one item of the
      modality can occupy.
    - `max_items_per_prompt_by_modality`: how many items of the modality fit
      in a single prompt.
    - `max_items_per_batch_by_modality`: how many items of the modality fit in
      a single batch, bounded by both the encoder and the decoder budgets.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        scheduler_config: SchedulerConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.mm_registry = mm_registry

        # Processor-only cache, reused by the registry queries below.
        self.cache = cache = processor_only_cache_from_config(
            model_config, mm_registry)

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
                                                              cache=cache)

        max_tokens_by_modality = (
            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
                model_config, cache=cache))

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            max_tokens_by_modality,
        )

        # Must be set before `get_max_items` is called below, since it reads
        # them through `get_encoder_budget`.
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        max_items_per_prompt_by_modality = dict[str, int]()
        max_items_per_batch_by_modality = dict[str, int]()

        for modality, max_tokens in max_tokens_by_modality.items():
            (
                max_items_per_prompt,
                max_items_per_batch,
            ) = self.get_max_items(modality, max_tokens)

            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
            max_items_per_batch_by_modality[modality] = max_items_per_batch

        self.max_tokens_by_modality = max_tokens_by_modality
        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item can occupy the most tokens.

        Raises:
            ValueError: If `max_tokens_by_modality` is empty.
        """
        max_tokens_by_modality = self.max_tokens_by_modality
        if not max_tokens_by_modality:
            raise ValueError(
                "Cannot determine the modality with the most tokens: no "
                "modality has a non-zero number of tokens per item.")

        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])

        return modality

    def get_encoder_budget(self) -> int:
        """Return the effective encoder budget, in tokens.

        This is the tighter of the encoder compute budget and the encoder
        cache size.
        """
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Compute how many items of `modality` fit per prompt and per batch.

        Args:
            modality: Name of the modality; must be a key of `self.mm_limits`.
            max_tokens_per_item: Maximum number of tokens a single item of
                this modality can occupy.

        Returns:
            A `(max_items_per_prompt, max_items_per_batch)` tuple. Both are 0
            when `max_tokens_per_item` is 0 or the encoder budget is 0;
            otherwise both are at least 1.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = self.get_encoder_budget()

        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        # Without chunked prefill, additionally bound the number of requests
        # by how many whole items fit into one batch's token budget.
        if not scheduler_config.enable_chunked_prefill:
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch


130
131
132
@dataclass
class AttentionGroup:
    """A group of attention layers sharing one backend and one KV cache spec,
    together with the metadata builders used for them."""

    # Attention backend class used by the layers in this group.
    backend: type[AttentionBackend]
    # Metadata builders indexed by micro-batch (ubatch) id; index 0 is the
    # default when no ubatch id is given.
    metadata_builders: list[AttentionMetadataBuilder]
    # Names of the attention layers belonging to this group.
    layer_names: list[str]
    # KV cache specification of the layers in this group.
    kv_cache_spec: KVCacheSpec

    def get_metadata_builder(self,
                             ubatch_id: Optional[int] = None
                             ) -> AttentionMetadataBuilder:
        """Return the metadata builder for the given micro-batch.

        Args:
            ubatch_id: Micro-batch index into `metadata_builders`. `None`
                selects the first builder.

        Returns:
            The selected `AttentionMetadataBuilder`.
        """
        if ubatch_id is None:
            return self.metadata_builders[0]
        # Reject negative ids too: they would otherwise silently index
        # from the end of the list.
        assert 0 <= ubatch_id < len(self.metadata_builders), (
            f"ubatch_id={ubatch_id} is out of range for "
            f"{len(self.metadata_builders)} metadata builder(s).")
        return self.metadata_builders[ubatch_id]

145

146
def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,
    expected_num_items: int,
) -> None:
    """
    Perform sanity checks for the result of
    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].

    Checks that `mm_embeddings` is a list, tuple or tensor, that it has
    exactly `expected_num_items` entries, and that every entry is 2D.

    Args:
        mm_embeddings: The embeddings returned by the model's
            `get_multimodal_embeddings` method.
        expected_num_items: The number of multimodal input items the
            embeddings are expected to correspond to.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")

    # Iterating a 3D tensor yields its 2D slices, so this covers both the
    # sequence-of-tensors and the single-tensor forms.
    assert all(e.ndim == 2 for e in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
        "instead. This is most likely due to incorrect implementation "
        "of the model's `get_multimodal_embeddings` method.")
171
172
173
174
175
176
177
178
179
180


def scatter_mm_placeholders(
    embeds: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Scatter the multimodal embeddings into a contiguous tensor that represents
    the placeholder tokens.

    Placeholder positions not selected by `is_embed` are filled with NaN.
    See also [`vllm.multimodal.processing.PromptUpdateDetails.is_embed`][].

    Args:
        embeds: The multimodal embeddings.
            Shape: `(num_embeds, embed_dim)`
        is_embed: A boolean mask indicating which positions in the placeholder
            tokens need to be filled with multimodal embeddings. It must
            contain exactly `num_embeds` `True` entries. If `None`, `embeds`
            is returned unchanged.
            Shape: `(num_placeholders,)`

    Returns:
        A tensor of shape `(num_placeholders, embed_dim)`, or `embeds` itself
        when `is_embed` is `None`.
    """
    if is_embed is None:
        return embeds

    placeholders = embeds.new_full(
        (is_embed.shape[0], embeds.shape[-1]),
        fill_value=torch.nan,
    )
    placeholders[is_embed] = embeds
    return placeholders


def gather_mm_placeholders(
    placeholders: torch.Tensor,
    is_embed: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    Reconstructs the embeddings from the placeholder tokens.

    This is the inverse operation of
    [`scatter_mm_placeholders`][vllm.v1.worker.utils.scatter_mm_placeholders].

    Args:
        placeholders: The placeholder embeddings, e.g. as produced by
            `scatter_mm_placeholders`.
            Shape: `(num_placeholders, embed_dim)`
        is_embed: A boolean mask selecting the rows of `placeholders` that
            hold multimodal embeddings. If `None`, `placeholders` is returned
            unchanged.
            Shape: `(num_placeholders,)`

    Returns:
        `placeholders[is_embed]`, or `placeholders` itself when `is_embed`
        is `None`.
    """
    if is_embed is None:
        return placeholders

    return placeholders[is_embed]
215
216


217
def add_kv_sharing_layers_to_kv_cache_groups(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    runner_only_attn_layers: Optional[set[str]] = None,
) -> None:
    """
    Add layers that reuse another layer's KV cache to the KV cache group of
    that target layer.

    Layers listed in `shared_kv_cache_layers` do not allocate their own KV
    cache. Adding them to the corresponding KV cache group ensures that
    attention metadata is assigned to them later.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model. Modified in place:
            each sharing layer's name is appended to `layer_names` of the
            group that contains its target layer.
        runner_only_attn_layers: If given, updated in place with the name of
            every sharing layer.

    Raises:
        KeyError: If a target layer is not in any of `kv_cache_groups`.
    """
    # Built from the groups as passed in, so targets are resolved only
    # against layers originally present in `kv_cache_groups`.
    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {}
    for kv_cache_group in kv_cache_groups:
        for layer_name in kv_cache_group.layer_names:
            layer_to_kv_cache_group[layer_name] = kv_cache_group

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        tgt_kv_cache_group = layer_to_kv_cache_group[target_layer_name]
        tgt_kv_cache_group.layer_names.append(layer_name)

        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)

247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: list[torch.Tensor],
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches, ordered by layer index. When several layers share one
         layer index, only the first of them (in `kv_caches` iteration
         order) is appended.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner. Must be empty
            on entry; it is filled in place.

    Raises:
        NotImplementedError: If several layers share a layer index and the
            current platform is neither CUDA nor XPU.
    """
    # Bind kv_caches to ModelRunner
    assert len(runner_kv_caches) == 0

    # Convert kv_caches dict to a list of tensors in the order of layer_index.
    index2name = defaultdict(list)
    for layer_name in kv_caches:
        index2name[extract_layer_index(layer_name)].append(layer_name)

    for layer_index in sorted(index2name.keys()):
        layer_names = index2name[layer_index]
        if len(layer_names) > 1:
            # One typical case is encoder-decoder model, e.g., bart.
            # The cross attention and self attention in the same decoder layer
            # has different layer_name but the same layer_index.

            # TODO - analyze where runner_kv_caches is used and the right
            # way to ensure it properly reflects multiple attention layers
            # in the same decoder block.

            # We know that the GPU runner is not impacted by this case. Some
            # test code depends on runner_kv_caches, but not in a way that's
            # impacted by ignoring this.
            if not (current_platform.is_cuda() or current_platform.is_xpu()):
                raise NotImplementedError(
                    "Binding KV caches for multiple attention layers with the "
                    f"same layer index ({layer_index}: {layer_names}) is only "
                    "implemented for CUDA and XPU platforms.")
        layer_name = layer_names[0]
        runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325


def is_residual_scattered_for_sp(vllm_config: VllmConfig,
                                 num_input_tokens: int) -> bool:
    """Check if the residual tensor is scattered for sequence parallelism.

    The residual tensor is scattered across tensor parallel ranks when
    sequence parallelism and tensor parallelism are both enabled, and the
    number of input tokens is one of the compilation sizes.

    Args:
        vllm_config: The vLLM configuration.
        num_input_tokens: Number of input tokens in the batch. When sequence
            parallelism is enabled with a tensor parallel size > 1, this must
            be a multiple of the tensor parallel size.

    Returns:
        True if the residual is scattered across tensor parallel ranks.
    """
    compilation_config = vllm_config.compilation_config
    if not compilation_config.pass_config.enable_sequence_parallelism:
        return False

    tp = vllm_config.parallel_config.tensor_parallel_size

    if tp == 1:
        return False

    # When sequence parallelism is enabled, we always pad num_input_tokens
    # to be a multiple of tensor_parallel_size (tp) earlier.
    assert num_input_tokens % tp == 0

    # Currently, SP is only enabled for static size fx graphs.
    return num_input_tokens in compilation_config.compile_sizes