registry.py 11.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
6

7
from vllm.config.multimodal import BaseDummyOptions
8
from vllm.config.observability import ObservabilityConfig
9
from vllm.logger import init_logger
10
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
11

12
from .cache import BaseMultiModalProcessorCache
13
14
15
16
17
18
19
20
21
22
from .processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
from .profiling import (
    BaseDummyInputsBuilder,
    DummyDecoderData,
    MultiModalProfiler,
)
23

24
if TYPE_CHECKING:
25
    from vllm.config import ModelConfig, ObservabilityConfig
26
    from vllm.model_executor.models.interfaces import SupportsMultiModal
27

28
29
logger = init_logger(__name__)

30
N = TypeVar("N", bound=type["SupportsMultiModal"])
31
32
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
33
34


35
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
46
47


48
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseDummyInputsBuilder`][vllm.multimodal.profiling.BaseDummyInputsBuilder]
    instance for the given processing info.
    """

    def __call__(
        self,
        info: _I,
    ) -> BaseDummyInputsBuilder[_I]: ...
56
57


58
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    instance from a processing info object and its dummy inputs builder.

    A processor cache may additionally be supplied through the keyword-only
    `cache` argument; it defaults to None.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
72

73

74
75
76
77
78
79
80
81
82
83
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories needed to build a processor.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ):
        """
        Chain the factories to produce a processor for `ctx`.

        The info factory runs first; its result is fed to the dummy inputs
        factory, and both results are then handed to the processor factory
        together with `cache`.
        """
        processing_info = self.info(ctx)
        return self.processor(
            processing_info,
            self.dummy_inputs(processing_info),
            cache=cache,
        )


91
92
class MultiModalRegistry:
    """
93
    A registry that dispatches data processing according to the model.
94
95
    """

96
97
98
    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Collect the per-modality dummy options declared in the model config.

        Modalities are taken from `limit_per_prompt`; a modality is kept only
        when `get_dummy_options` returns something other than None for it.

        Returns:
            A mapping from modality name to its dummy options, or None when
            the model has no multimodal config or no modality has options.
        """
        mm_config = model_config.multimodal_config
        if not mm_config:
            return None

        collected: dict[str, BaseDummyOptions] = {}
        for modality in mm_config.limit_per_prompt:
            options = mm_config.get_dummy_options(modality)
            if options is not None:
                collected[modality] = options

        # An empty result is reported as None rather than as an empty mapping.
        return collected or None

117
    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Tell whether multimodal inputs are usable for this model.

        Returns:
            False when the model is not multimodal, or when every modality
            supported by the model has a per-prompt limit of 0 (in which case
            the model effectively runs in text-only mode); True otherwise.
        """
        if not model_config.is_multimodal_model:
            return False

        processing_info = self._create_processing_info(model_config, tokenizer=None)
        modalities = processing_info.get_supported_mm_limits()
        mm_config = model_config.get_multimodal_config()

        # One modality with a non-zero limit is enough to keep multimodal
        # inputs enabled.
        for modality in modalities:
            if mm_config.get_limit_per_prompt(modality) != 0:
                return True

        logger.info_once(
            "All limits of multimodal modalities supported by the model "
            "are set to 0, running in text-only mode."
        )
        return False

145
146
    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        profiler_limits: Mapping[str, int] | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Compute the maximum token count of a single data item per modality,
        as reported by the profiler for the model's configuration.

        Args:
            model_config: Configuration of the model to inspect.
            cache: Optional processor cache forwarded to `create_processor`.
            profiler_limits: Per-modality limits to use; when None they are
                obtained from the profiler.
            observability_config: Optional observability settings forwarded
                to `create_processor`.

        Returns:
            An empty mapping for non-multimodal models; otherwise the result
            of the profiler's `get_mm_max_tokens`.
        """
        if not model_config.is_multimodal_model:
            return {}

        profiler: MultiModalProfiler = MultiModalProfiler(
            self.create_processor(model_config, observability_config, cache=cache)
        )
        max_seq_len = model_config.max_model_len

        if profiler_limits is None:
            profiler_limits = profiler.get_mm_limits()

        # Query with exactly one item for each modality whose limit is
        # positive; modalities with a limit of 0 or less are left out.
        single_item_counts = dict.fromkeys(
            (name for name, limit in profiler_limits.items() if limit > 0),
            1,
        )
        return profiler.get_mm_max_tokens(max_seq_len, single_item_counts)
174

175
176
    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Report, for each modality, how many multi-modal input instances a
        single prompt may contain for this model.

        Returns:
            An empty mapping for non-multimodal models; otherwise the limits
            returned by the profiler's `get_mm_limits`.
        """
        if not model_config.is_multimodal_model:
            return {}

        profiler: MultiModalProfiler = MultiModalProfiler(
            self.create_processor(model_config, observability_config, cache=cache)
        )
        return profiler.get_mm_limits()
194
195
196

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Return a class decorator that registers a multi-modal processor on a
        model class.

        Factories are passed instead of instances because the processor is
        constructed lazily. The decorator stores them on the class as
        `_processor_factory` and returns the class unchanged otherwise.

        Args:
            processor: Factory for the multi-modal processor.
            info: Factory for the processing info.
            dummy_inputs: Factory for the dummy inputs builder.
        """

        def wrapper(model_cls: N) -> N:
            # Only the class's own namespace is inspected, so a factory
            # inherited from a base class does not trigger the warning.
            already_registered = "_processor_factory" in vars(model_cls)
            if already_registered:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            factories = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )
            model_cls._processor_factory = factories
            return model_cls

        return wrapper

229
    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Look up the model class for `model_config` and check that a
        processor factory has been attached to it.
        """
        # Imported locally instead of at module level to avoid a circular
        # import.
        from vllm.model_executor.model_loader import get_model_architecture

        resolved_cls, _ = get_model_architecture(model_config)
        assert hasattr(resolved_cls, "_processor_factory")
        return cast("SupportsMultiModal", resolved_cls)
236

237
238
239
    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build the `InputProcessingContext` for a model.

        When no tokenizer is supplied, one is obtained through
        `cached_tokenizer_from_config`, unless the model config sets
        `skip_tokenizer_init`, in which case the context receives None.
        """
        needs_tokenizer = tokenizer is None and not model_config.skip_tokenizer_init
        if needs_tokenizer:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(
            model_config,
            tokenizer,
            observability_config=observability_config,
        )
249

250
251
    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info for a model, without constructing the
        full processor.
        """
        factories = self._get_model_cls(model_config)._processor_factory
        processing_ctx = self._create_processing_ctx(
            model_config,
            observability_config,
            tokenizer,
        )
        return factories.info(processing_ctx)

262
263
    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Build the multi-modal processor registered for the given model.

        Args:
            model_config: Configuration of the model.
            observability_config: Optional observability settings passed to
                the processing context.
            tokenizer: Tokenizer to use for the processing context.
            cache: Optional processor cache handed to the processor factory.

        Raises:
            ValueError: If the model is not a multimodal model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        factories = self._get_model_cls(model_config)._processor_factory
        processing_ctx = self._create_processing_ctx(
            model_config,
            observability_config,
            tokenizer,
        )
        return factories.build_processor(processing_ctx, cache=cache)
282
283
284

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> DummyDecoderData:
        """
        Produce dummy decoder data used to profile a model's memory usage.

        Args:
            model_config: Identifies the model to profile.
            seq_len: Minimum number of prompt tokens the dummy data must have.
            mm_counts: Optional per-modality item counts passed to the
                profiler.
            cache: Optional processor cache forwarded to `create_processor`.
            observability_config: Optional observability settings forwarded
                to `create_processor`.

        Raises:
            AssertionError: If the profiler yields fewer than `seq_len`
                prompt tokens.
        """
        profiler: MultiModalProfiler = MultiModalProfiler(
            self.create_processor(model_config, observability_config, cache=cache)
        )

        # Configurable options come from the multimodal config. Only
        # modalities that use advanced option types are included, so the
        # legacy count-only behavior remains unchanged.
        dummy_data = profiler.get_decoder_dummy_data(
            seq_len,
            mm_counts,
            self._extract_mm_options(model_config),
        )

        # Having more tokens than requested is over-conservative but
        # otherwise fine; having fewer is an error.
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) >= seq_len:
            return dummy_data

        raise AssertionError(
            f"Expected at least {seq_len} dummy tokens for profiling, "
            f"but found {len(token_ids)} tokens instead."
        )

319
    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Return the maximum encoder input length for an encoder-decoder model.

        Returns:
            0 when the model is not encoder-decoder or reports no per-modality
            token maximum; otherwise the maximum token count of its single
            modality.
        """
        if not model_config.is_encoder_decoder:
            return 0

        tokens_by_modality = self.get_max_tokens_per_item_by_modality(model_config)
        if not tokens_by_modality:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0

        assert len(tokens_by_modality) == 1, (
            "Encoder-decoder models are expected "
            "to implement the multimodal interface with at most one modality."
        )

        # Exactly one entry remains at this point; return its value.
        return next(iter(tokens_by_modality.values()))