multimodal.py 9.49 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping
5
from typing import Any, Literal, TypeAlias
6

7
from pydantic import ConfigDict, Field, field_validator, model_validator
8
9
from pydantic.dataclasses import dataclass

10
from vllm.attention.backends.registry import AttentionBackendEnum
11
from vllm.config.utils import config
12
from vllm.utils.hashing import safe_hash
13

14
15
16
17

@dataclass
class BaseDummyOptions:
    """Base options for generating dummy data during profiling.

    Attributes:
        count: Maximum number of items of this modality allowed per prompt.
            Defaults to 999 and must be non-negative.
    """

    count: int = Field(999, ge=0)


@dataclass(config=ConfigDict(extra="forbid"))
class VideoDummyOptions(BaseDummyOptions):
    """Options for generating dummy video data during profiling.

    Unknown option keys are rejected (`extra="forbid"`).

    Attributes:
        num_frames: Number of frames of the dummy video; must be positive.
            `None` when not specified.
        width: Width of the dummy video; must be positive.
            `None` when not specified.
        height: Height of the dummy video; must be positive.
            `None` when not specified.
    """

    num_frames: int | None = Field(None, gt=0)
    width: int | None = Field(None, gt=0)
    height: int | None = Field(None, gt=0)


@dataclass(config=ConfigDict(extra="forbid"))
class ImageDummyOptions(BaseDummyOptions):
    """Options for generating dummy image data during profiling.

    Unknown option keys are rejected (`extra="forbid"`).

    Attributes:
        width: Width of the dummy image; must be positive.
            `None` when not specified.
        height: Height of the dummy image; must be positive.
            `None` when not specified.
    """

    width: int | None = Field(None, gt=0)
    height: int | None = Field(None, gt=0)


@dataclass(config=ConfigDict(extra="forbid"))
class AudioDummyOptions(BaseDummyOptions):
    """Options for generating dummy audio data during profiling.

    Unknown option keys are rejected (`extra="forbid"`).

    Attributes:
        length: Length of the dummy audio; must be positive. `None` when not
            specified. NOTE(review): unit is not defined here (presumably
            samples) — confirm against the dummy-data builder.
    """

    length: int | None = Field(None, gt=0)


46
47
# Allowed values of `MultiModalConfig.mm_encoder_tp_mode`.
MMEncoderTPMode = Literal["weights", "data"]
# Allowed values of `MultiModalConfig.mm_processor_cache_type`.
MMCacheType = Literal["shm", "lru"]
# Every options type that may appear as a value of
# `MultiModalConfig.limit_per_prompt`.
DummyOptions: TypeAlias = (
    BaseDummyOptions | VideoDummyOptions | ImageDummyOptions | AudioDummyOptions
)


@config
@dataclass
class MultiModalConfig:
    """Controls the behavior of multimodal models."""

    limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict)
    """The maximum number of input items and options allowed per prompt for
    each modality.
    Defaults to 999 for each modality.

    Legacy format (count only):
        {"image": 16, "video": 2}

    Configurable format (with options):
        {"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
        "image": {"count": 5, "width": 512, "height": 512}}

    Mixed format (combining both):
        {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
        "height": 512}}
    """
    enable_mm_embeds: bool = False
    """If `True`, enables passing multimodal embeddings:
    for `LLM` class, this refers to tensor inputs under `multi_modal_data`;
    for the OpenAI-compatible server, this refers to chat messages with content
    `"type": "*_embeds"`.

    WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed.
    Only enable this flag for trusted users!"""
    mm_processor_kwargs: dict[str, object] | None = None
    """Arguments to be forwarded to the model's processor for multi-modal data,
    e.g., image processor. Overrides for the multi-modal processor obtained
    from `transformers.AutoProcessor.from_pretrained`.

    The available overrides depend on the model that is being run.

    For example, for Phi-3-Vision:
    `{"num_crops": 4}`."""
    mm_processor_cache_gb: float = Field(default=4, ge=0)
    """The size (in GiB) of the multi-modal processor cache, which is used to
    avoid re-processing past multi-modal inputs.

    This cache is duplicated for each API process and engine core process,
    resulting in a total memory usage of
    `mm_processor_cache_gb * (api_server_count + data_parallel_size)`.

    Set to `0` to disable this cache completely (not recommended)."""
    mm_processor_cache_type: MMCacheType = "lru"
    """Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
    use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
    mm_shm_cache_max_object_size_mb: int = Field(default=128, ge=0)
    """Size limit (in MiB) for each object stored in the multi-modal processor
    shared memory cache. Only effective when `mm_processor_cache_type` is
    `"shm"`."""
    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
    """Indicates how to optimize multi-modal encoder inference using tensor
    parallelism (TP).

    - `"weights"`: Within the same vLLM engine, split the weights of
        each layer across TP ranks. (default TP behavior)\n
    - `"data"`: Within the same vLLM engine, split the batched input data
        across TP ranks to process the data in parallel, while hosting
        the full weights on each TP rank.
        This batch-level DP is not to be confused with API request-level
        DP (which is controlled by `--data-parallel-size`).
        This is only supported on a per-model basis and falls back to
        `"weights"` if the encoder does not support DP."""
    mm_encoder_attn_backend: AttentionBackendEnum | None = None
    """Optional override for the multi-modal encoder attention backend when
    using vision transformers. Accepts any value from
    `vllm.attention.backends.registry.AttentionBackendEnum` (e.g. `FLASH_ATTN`)."""
    interleave_mm_strings: bool = False
    """Enable fully interleaved support for multimodal prompts, while using
    --chat-template-content-format=string."""
    skip_mm_profiling: bool = False
    """When enabled, skips multimodal memory profiling and only profiles with
    language backbone model during engine initialization.

    This reduces engine startup time but shifts the responsibility to users for
    estimating the peak memory usage of the activation of multimodal encoder and
    embedding cache."""
    video_pruning_rate: float | None = Field(default=None, ge=0.0, lt=1.0)
    """Sets pruning rate for video pruning via Efficient Video Sampling.
    Value sits in range [0, 1) and determines the fraction of media tokens
    from each video to be pruned.
    """

    @field_validator("limit_per_prompt", mode="before")
    @classmethod
    def _validate_limit_per_prompt(
        cls, value: dict[str, int | dict[str, int]]
    ) -> dict[str, DummyOptions]:
        """Normalize every `limit_per_prompt` entry into a DummyOptions object.

        - An `int` (legacy format) is treated as `{"count": value}`.
        - A mapping is expanded into the options class for its modality
          (`VideoDummyOptions`, `ImageDummyOptions`, `AudioDummyOptions`, or
          `BaseDummyOptions` for any other modality).
        - An existing `BaseDummyOptions` instance is kept as-is.

        A new dict is returned; the caller's mapping is never mutated.

        Raises:
            ValueError: if an entry is neither an int, a mapping, nor a
                `BaseDummyOptions` instance.
        """
        if not isinstance(value, Mapping):
            # Not a dict at all: let pydantic's own type validation report it.
            return value

        options_cls_by_modality: dict[str, type[BaseDummyOptions]] = {
            "video": VideoDummyOptions,
            "image": ImageDummyOptions,
            "audio": AudioDummyOptions,
        }
        normalized: dict[str, DummyOptions] = {}
        for modality, limit in value.items():
            if isinstance(limit, BaseDummyOptions):
                # Already normalized, e.g. options built programmatically.
                normalized[modality] = limit
                continue
            # Handle legacy format where only count is specified
            if isinstance(limit, int):
                limit = {"count": limit}
            if not isinstance(limit, Mapping):
                raise ValueError(
                    f"limit_per_prompt[{modality!r}] must be an int or a mapping "
                    f"of dummy-data options, got {limit!r}."
                )
            options_cls = options_cls_by_modality.get(modality, BaseDummyOptions)
            normalized[modality] = options_cls(**limit)
        return normalized

    @field_validator("mm_encoder_attn_backend", mode="before")
    @classmethod
    def _validate_mm_encoder_attn_backend(
        cls, value: str | AttentionBackendEnum | None
    ) -> AttentionBackendEnum | None:
        """Coerce a backend name (case-insensitive) into `AttentionBackendEnum`.

        `None` and enum members pass through unchanged.

        Raises:
            ValueError: if `value` names the removed XFORMERS backend, is not
                the name of an `AttentionBackendEnum` member, or is neither a
                string, an `AttentionBackendEnum`, nor `None`.
        """
        if value is None or isinstance(value, AttentionBackendEnum):
            return value

        # Raise ValueError rather than using `assert` or letting KeyError
        # escape: pydantic converts ValueError into a ValidationError, and the
        # check must not disappear under `python -O`.
        if not isinstance(value, str):
            raise ValueError(
                "mm_encoder_attn_backend must be a string or an AttentionBackendEnum."
            )

        backend_name = value.upper()
        if backend_name == "XFORMERS":
            raise ValueError(
                "Attention backend 'XFORMERS' has been removed (See PR #29262 for "
                "details). Please select a supported attention backend."
            )
        try:
            return AttentionBackendEnum[backend_name]
        except KeyError:
            # A plain Enum lookup signals an unknown name with KeyError.
            valid_names = ", ".join(member.name for member in AttentionBackendEnum)
            raise ValueError(
                f"Unknown mm_encoder_attn_backend {value!r}. "
                f"Valid options are: {valid_names}."
            ) from None

    @model_validator(mode="after")
    def _validate_multimodal_config(self):
        """Check cross-field consistency after all fields are validated.

        Raises:
            ValueError: if `mm_shm_cache_max_object_size_mb` differs from its
                default while `mm_processor_cache_type` is not `"shm"`.
        """
        # NOTE(review): relies on the class attribute holding the declared
        # default (128) after pydantic processes the `Field(...)` — confirm.
        if self.mm_processor_cache_type != "shm" and (
            self.mm_shm_cache_max_object_size_mb
            != MultiModalConfig.mm_shm_cache_max_object_size_mb
        ):
            raise ValueError(
                "'mm_shm_cache_max_object_size_mb' should only be set when "
                "'mm_processor_cache_type' is 'shm'."
            )
        return self

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        factors: list[Any] = [
            self.mm_encoder_attn_backend.name
            if self.mm_encoder_attn_backend is not None
            else None
        ]
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str

    def get_limit_per_prompt(self, modality: str) -> int:
        """
        Get the maximum number of input items allowed per prompt
        for the given modality (backward compatible).
        """
        limit_data = self.limit_per_prompt.get(modality)

        if limit_data is None:
            # Unspecified modality is set to 999 by default
            return 999
        return limit_data.count

    def get_dummy_options(self, modality: str) -> BaseDummyOptions | None:
        """
        Get the configurable dummy data options for a modality.
        Returns None if no options are configured for this modality.
        """
        # All values are now DummyOptions after normalization
        return self.limit_per_prompt.get(modality)

    def merge_mm_processor_kwargs(
        self,
        inference_kwargs: Mapping[str, object],
    ) -> dict[str, object]:
        """
        Get the keyword arguments to pass to the multi-modal processor
        according to the extra arguments passed during inference.

        Inference-time kwargs take precedence over `mm_processor_kwargs`.
        A new dict is returned; neither input is modified.
        """
        kwargs = self.mm_processor_kwargs or {}
        return kwargs | dict(inference_kwargs)

    def is_multimodal_pruning_enabled(self) -> bool:
        """Return True when a positive `video_pruning_rate` is configured."""
        return self.video_pruning_rate is not None and self.video_pruning_rate > 0