utils.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import math
4
from collections import defaultdict
5
from dataclasses import dataclass, field
6

7
8
import torch

9
from vllm.config import CacheConfig, VllmConfig
10
from vllm.logger import init_logger
11
from vllm.model_executor.layers.attention import Attention
12
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
13
from vllm.model_executor.models.utils import extract_layer_index
14
from vllm.multimodal.registry import MultiModalRegistry
15
from vllm.platforms import current_platform
16
from vllm.utils.mem_utils import MemorySnapshot, format_gib
17
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
18
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
19
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
20

21
22
# Module-level logger, created via vLLM's `init_logger` and named after this module.
logger = init_logger(__name__)

class MultiModalBudget:
    """Derive per-modality item limits for a multi-modal model.

    The limits combine the encoder-side budget (compute budget and encoder
    cache size obtained from ``compute_mm_encoder_budget``) with decoder-side
    limits (``max_model_len``, ``max_num_seqs`` and, when chunked prefill is
    disabled, ``max_num_batched_tokens``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        model_config = vllm_config.model_config
        scheduler_config = vllm_config.scheduler_config
        # Processor-only cache (may be None, see `reset_cache`); it is also
        # handed to the registry when measuring per-item token counts.
        cache = mm_registry.processor_only_cache_from_config(vllm_config)

        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.mm_registry = mm_registry
        self.cache = cache

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs
        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)

        tokens_per_item = mm_registry.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
            profiler_limits=self.mm_limits,
        )

        # These two attributes must exist before `get_max_items` runs below,
        # because `get_encoder_budget` reads them.
        (
            self.encoder_compute_budget,
            self.encoder_cache_size,
        ) = compute_mm_encoder_budget(scheduler_config, tokens_per_item)

        per_prompt: dict[str, int] = {}
        per_batch: dict[str, int] = {}
        for modality, max_tokens in tokens_per_item.items():
            prompt_cap, batch_cap = self.get_max_items(modality, max_tokens)
            per_prompt[modality] = prompt_cap
            per_batch[modality] = batch_cap

        self.max_tokens_by_modality = tokens_per_item
        self.max_items_per_prompt_by_modality = per_prompt
        self.max_items_per_batch_by_modality = per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item needs the most tokens.

        Ties resolve to the modality that appears first in
        ``max_tokens_by_modality``; an empty mapping makes ``max`` raise
        ``ValueError``.
        """
        tokens = self.max_tokens_by_modality
        return max(tokens, key=tokens.__getitem__)

    def get_encoder_budget(self) -> int:
        """Return the smaller of the encoder compute budget and cache size."""
        compute_budget = self.encoder_compute_budget
        cache_size = self.encoder_cache_size
        return min(compute_budget, cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max_items_per_prompt, max_items_per_batch)`` for a modality.

        ``(0, 0)`` is returned when an item needs zero tokens or when the
        encoder budget is zero; otherwise both values are at least 1.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        encoder_budget = self.get_encoder_budget()
        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        # Encoder side: how many items fit in the encoder budget.
        encoder_cap = encoder_budget // max_tokens_per_item

        # Decoder side: items per prompt are bounded by the configured
        # per-prompt limit and by how many items fit in the context length.
        limit = self.mm_limits[modality]
        fits_in_context = self.max_model_len // max_tokens_per_item
        per_prompt = max(1, min(limit, fits_in_context))

        sched = self.scheduler_config
        if sched.enable_chunked_prefill:
            num_reqs = self.max_num_reqs
        else:
            # Without chunked prefill the request count is additionally
            # capped by how many items fit in max_num_batched_tokens.
            num_reqs = min(
                self.max_num_reqs,
                sched.max_num_batched_tokens // max_tokens_per_item,
            )
        decoder_cap = num_reqs * per_prompt

        per_batch = max(1, min(encoder_cap, decoder_cap))
        return per_prompt, per_batch

    def reset_cache(self) -> None:
        """Clear the processor-only cache, if one was created."""
        cache = self.cache
        if cache is None:
            return
        cache.clear_cache()

@dataclass
class AttentionGroup:
    """A set of attention layers plus the backend, KV cache spec and KV cache
    group id recorded for them.

    ``metadata_builders`` starts empty and is (re)populated by
    ``create_metadata_builders``.
    """

    backend: type[AttentionBackend]
    layer_names: list[str]
    kv_cache_spec: KVCacheSpec
    kv_cache_group_id: int
    # When ubatching is enabled there is one metadata builder per ubatch, so
    # builders that keep internal persistent buffers (e.g. for cudagraphs)
    # do not have to worry about conflicting with the other ubatches.
    metadata_builders: list[AttentionMetadataBuilder] = field(default_factory=list)

    def create_metadata_builders(
        self,
        vllm_config,
        device,
        kernel_block_size: int | None,
        num_metadata_builders: int = 1,
    ):
        """Replace ``self.metadata_builders`` with freshly constructed builders.

        Args:
            vllm_config: Passed through to each builder's constructor.
            device: Passed through to each builder's constructor.
            kernel_block_size: When not None, builders receive a copy of
                ``kv_cache_spec`` with this block size (via
                ``copy_with_new_block_size``); otherwise they receive
                ``kv_cache_spec`` itself.
            num_metadata_builders: Number of independent builders to create
                (one per ubatch).
        """
        if kernel_block_size is None:
            spec = self.kv_cache_spec
        else:
            spec = self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)

        # The builder class does not change between iterations; look it up once.
        builder_cls = self.backend.get_builder_cls()
        self.metadata_builders = [
            builder_cls(spec, self.layer_names, vllm_config, device)
            for _ in range(num_metadata_builders)
        ]

    def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
        """Return the metadata builder for ``ubatch_id`` (default: the first).

        ``ubatch_id`` must be a valid non-negative index into
        ``metadata_builders``. Negative ids are rejected instead of silently
        indexing from the end of the list.
        """
        num_builders = len(self.metadata_builders)
        assert 0 <= ubatch_id < num_builders, (
            f"ubatch_id {ubatch_id} out of range for {num_builders} "
            "metadata builders"
        )
        return self.metadata_builders[ubatch_id]

def sanity_check_mm_encoder_outputs(
    mm_embeddings: MultiModalEmbeddings,
    expected_num_items: int,
) -> None:
    """
    Validate the value returned by
    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].

    Three `assert` checks run in order: the container type (list, tuple or
    `torch.Tensor`), the number of items against `expected_num_items`, and
    that every item has exactly two dimensions. A failed check raises
    `AssertionError` pointing at the model's `embed_multimodal` method.
    """
    # Common tail shared by every failure message below.
    blame = (
        "instead. This is most likely due to incorrect implementation "
        "of the model's `embed_multimodal` method."
    )

    assert isinstance(mm_embeddings, (list, tuple, torch.Tensor)), (
        "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
        f"or a single 3D tensor, but got {type(mm_embeddings)} " + blame
    )

    assert len(mm_embeddings) == expected_num_items, (
        "Expected number of multimodal embeddings to match number of "
        f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
        + blame
    )

    assert all(emb.ndim == 2 for emb in mm_embeddings), (
        "Expected multimodal embeddings to be a sequence of 2D tensors, "
        f"but got tensors with shapes {[emb.shape for emb in mm_embeddings]} "
        + blame
    )


def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) -> int:
    """
    Compute how much device memory vLLM will request and check that this
    amount is currently free.

    The request is `ceil(init_snapshot.total_memory *
    cache_config.gpu_memory_utilization)`, in the same units as
    `init_snapshot.total_memory`.

    Raises:
        ValueError: if `init_snapshot.free_memory` is below the request.
    """
    total = init_snapshot.total_memory
    utilization = cache_config.gpu_memory_utilization
    requested_memory = math.ceil(total * utilization)

    free = init_snapshot.free_memory
    if free < requested_memory:
        message = (
            f"Free memory on device {init_snapshot.device_} "
            f"({format_gib(free)}/"
            f"{format_gib(total)} GiB) on startup "
            f"is less than desired GPU memory utilization "
            f"({utilization}, "
            f"{format_gib(requested_memory)} GiB). Decrease GPU memory "
            f"utilization or reduce GPU memory used by other processes."
        )
        raise ValueError(message)

    return requested_memory


def add_kv_sharing_layers_to_kv_cache_groups(
    shared_kv_cache_layers: dict[str, str],
    kv_cache_groups: list[KVCacheGroupSpec],
    runner_only_attn_layers: set[str] | None = None,
) -> None:
    """
    Register cross-layer KV sharing layers with the KV cache groups.

    Layers that do not allocate their own KV cache reuse the KV cache of
    another layer, as given by `shared_kv_cache_layers`. Each such layer is
    appended to the `layer_names` of the KV cache group that contains its
    target layer, which is needed to ensure that attention metadata is
    assigned to it later. `kv_cache_groups` is mutated in place.

    Args:
        shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
            If an Attention layer `layer_name` is in the keys of this dict, it
            means this layer will perform attention using the keys and values
            from the KV cache of `shared_kv_cache_layers[layer_name]`.
        kv_cache_groups: The KV cache groups of the model. The group that
            contains each target layer gets the sharing layer appended to its
            `layer_names`.
        runner_only_attn_layers: Optional set that, when given, also receives
            every sharing layer name (updated in place).

    Raises:
        KeyError: if a target layer is not listed in any of `kv_cache_groups`.
    """
    # Map each layer already listed in a group to that group. The map is built
    # before any appends, so every target must itself appear in a group.
    layer_to_kv_cache_group: dict[str, KVCacheGroupSpec] = {
        layer_name: group
        for group in kv_cache_groups
        for layer_name in group.layer_names
    }

    for layer_name, target_layer_name in shared_kv_cache_layers.items():
        layer_to_kv_cache_group[target_layer_name].layer_names.append(layer_name)
        if runner_only_attn_layers is not None:
            runner_only_attn_layers.add(layer_name)

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, Attention],
    runner_kv_caches: list[torch.Tensor],
    num_attn_module: int = 1,
) -> None:
    """
    Expose the allocated KV caches to the ModelRunner and the forward context.

    Two things happen:
      1) `runner_kv_caches` (the ModelRunner's list, which must be empty on
         entry) is filled with the tensors from `kv_caches`, ordered by the
         layer index that `extract_layer_index` parses from each layer name.
      2) Every layer in `forward_context` named in `kv_caches` gets its tensor
         assigned to its `kv_cache` attribute, wrapped in a one-element list.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
            layers with layer names as keys.
        runner_kv_caches: The kv_cache list declared by ModelRunner; appended
            to in place.
        num_attn_module: Forwarded to `extract_layer_index` when parsing the
            layer index out of each layer name.

    Raises:
        NotImplementedError: if several layer names share one layer index on
            a platform that is not CUDA-alike, XPU or CPU.
    """
    assert len(runner_kv_caches) == 0

    # Group layer names by layer index; several names may share one index.
    names_by_index = defaultdict(list)
    for name in kv_caches:
        names_by_index[extract_layer_index(name, num_attn_module)].append(name)

    for index in sorted(names_by_index):
        names = names_by_index[index]
        # One typical case of a shared index is an encoder-decoder model,
        # e.g. bart: the cross attention and self attention in the same
        # decoder layer have different layer names but the same layer index.
        # TODO - analyze where runner_kv_caches is used and the right way to
        # ensure it properly reflects multiple attention layers in the same
        # decoder block. The GPU / CPU runners are known not to be impacted
        # (some test code depends on runner_kv_caches, but not in a way that
        # is impacted by ignoring this); every other platform is rejected.
        if len(names) > 1 and not (
            current_platform.is_cuda_alike()
            or current_platform.is_xpu()
            or current_platform.is_cpu()
        ):
            raise NotImplementedError
        runner_kv_caches.extend(kv_caches[n] for n in names)

    # NOTE: Use list because of v0 PP virtual engine.
    for name, cache in kv_caches.items():
        forward_context[name].kv_cache = [cache]


def is_residual_scattered_for_sp(
    vllm_config: VllmConfig, num_input_tokens: int
) -> bool:
    """Check if the residual tensor is scattered for sequence parallelism.

    The residual tensor is scattered across tensor parallel ranks when both
    sequence parallelism and tensor parallelism (size > 1) are enabled.

    This follows the same logic as SequenceParallelismPass.is_applicable_for_range():

    - In full-graph compilation mode (no splitting ops, or inductor graph
      partition in use), SP is always applied.
    - Otherwise, SP is only applied for the shapes listed in compile_sizes.
    """
    compilation = vllm_config.compilation_config
    if not compilation.pass_config.enable_sp:
        return False

    tp_size = vllm_config.parallel_config.tensor_parallel_size
    if tp_size == 1:
        return False

    # With sequence parallelism enabled, num_input_tokens has already been
    # padded to a multiple of the tensor parallel size.
    assert num_input_tokens % tp_size == 0

    full_graph = (
        not compilation.splitting_ops or compilation.use_inductor_graph_partition
    )
    if full_graph:
        return True

    sizes = compilation.compile_sizes
    return sizes is not None and num_input_tokens in sizes