registry.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
6

7
from vllm.config.multimodal import BaseDummyOptions
8
from vllm.config.observability import ObservabilityConfig
9
from vllm.logger import init_logger
10
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
11

12
from .cache import BaseMultiModalProcessorCache
13
from .inputs import MultiModalInputs
14
from .processing import (
15
    BaseDummyInputsBuilder,
16
17
18
19
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    InputProcessingContext,
)
20

21
if TYPE_CHECKING:
22
    from vllm.config import ModelConfig, ObservabilityConfig
23
    from vllm.model_executor.models.interfaces import SupportsMultiModal
24

25
26
logger = init_logger(__name__)

27
# Type variable for a model *class* implementing `SupportsMultiModal`.
# It lets the decorator returned by `MultiModalRegistry.register_processor`
# hand back the decorated class with its exact type preserved.
N = TypeVar("N", bound=type["SupportsMultiModal"])
28
29
# Processing-info type shared by the factory protocols in this module.
# `_I` is invariant because it appears in both parameter and return positions.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant variant for `ProcessingInfoFactory`, which only ever returns it.
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
30
31


32
class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a
    [`BaseProcessingInfo`][vllm.multimodal.processing.BaseProcessingInfo]
    instance from the context.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co: ...
43
44


45
class DummyInputsBuilderFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseDummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
    for the given processing info.
    """

    def __call__(
        self,
        info: _I,
    ) -> BaseDummyInputsBuilder[_I]: ...
53
54


55
class MultiModalProcessorFactory(Protocol[_I]):  # type: ignore[misc]
    """
    Callable protocol that builds a
    [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
    from the processing info and its dummy inputs builder, optionally
    attaching a processor cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]: ...
69

70

71
72
73
74
75
76
77
78
79
80
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories attached to a model class by
    `MultiModalRegistry.register_processor`.
    """

    # Builds the processing info from an input processing context.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[_I]:
        """
        Chain the three factories to construct a processor.

        Args:
            ctx: Context handed to the `info` factory.
            cache: Optional processor cache forwarded to the `processor`
                factory.

        Returns:
            The processor produced by the `processor` factory.
        """
        # The info object is created once and shared by both the dummy
        # inputs builder and the processor.
        info = self.info(ctx)
        dummy_inputs_builder = self.dummy_inputs(info)
        return self.processor(info, dummy_inputs_builder, cache=cache)


88
89
class MultiModalRegistry:
    """
90
    A registry that dispatches data processing according to the model.
91
92
    """

93
94
95
    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Collect the per-modality dummy options configured for the model.

        Only modalities listed in `limit_per_prompt` whose
        `get_dummy_options` result is not `None` are included.

        Returns:
            A mapping from modality name to its dummy options, or `None`
            when the model has no multimodal config or no modality has
            dummy options.
        """
        if not model_config.multimodal_config:
            return None

        mm_config = model_config.multimodal_config
        mm_options = {}
        for modality in mm_config.limit_per_prompt:
            dummy_options = mm_config.get_dummy_options(modality)
            if dummy_options is not None:
                mm_options[modality] = dummy_options

        # An empty mapping is reported as "nothing configured".
        return mm_options or None

114
    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Tell whether the model should accept multimodal inputs.

        Returns False when the model is not multimodal, or when every
        modality the model supports has a per-prompt limit of 0 (the model
        then effectively runs in text-only mode). Returns True otherwise.
        """
        if not model_config.is_multimodal_model:
            return False

        info = self._create_processing_info(model_config, tokenizer=None)
        supported_modalities = info.get_supported_mm_limits()
        mm_config = model_config.get_multimodal_config()

        # A single modality with a non-zero limit is enough.
        has_enabled_modality = any(
            mm_config.get_limit_per_prompt(modality) != 0
            for modality in supported_modalities
        )
        if has_enabled_modality:
            return True

        logger.info_once(
            "All limits of multimodal modalities supported by the model "
            "are set to 0, running in text-only mode."
        )
        return False

142
143
    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        profiler_limits: Mapping[str, int] | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Compute, per modality, the maximum number of tokens a single data
        item can occupy for the given model configuration.

        Modalities whose limit (from `profiler_limits`, or from the
        processor's allowed limits when that is `None`) is not positive are
        excluded. Returns an empty mapping for non-multimodal models.
        """
        if not model_config.is_multimodal_model:
            return {}

        processor = self.create_processor(
            model_config, observability_config, cache=cache
        )

        limits = (
            processor.allowed_mm_limits if profiler_limits is None else profiler_limits
        )
        # One item per enabled modality is enough to measure a per-item maximum.
        mm_counts = {modality: 1 for modality, limit in limits.items() if limit > 0}

        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
            seq_len=model_config.max_model_len,
            mm_counts=mm_counts,
        )

        if max_tokens_per_item is None:
            # The info object gave no answer: process dummy inputs and count
            # the embeddings of the resulting placeholders instead.
            mm_inputs = self.get_dummy_mm_inputs(
                model_config,
                mm_counts=mm_counts,
                processor=processor,
            )
            return {
                modality: sum(item.get_num_embeds for item in placeholders)
                for modality, placeholders in mm_inputs["mm_placeholders"].items()
            }

        # Every value in `mm_counts` is 1, so membership is the enabled check.
        return {
            modality: max_tokens
            for modality, max_tokens in max_tokens_per_item.items()
            if modality in mm_counts
        }

190
191
    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
    ) -> Mapping[str, int]:
        """
        Report how many input instances of each modality a single prompt may
        contain, as given by the processor's `allowed_mm_limits`.

        Returns an empty mapping for non-multimodal models.
        """
        if model_config.is_multimodal_model:
            return self.create_processor(
                model_config, observability_config, cache=cache
            ).allowed_mm_limits

        return {}
208
209
210

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Return a class decorator that attaches multi-modal processor
        factories to a model class.

        Factories are stored rather than instances, so the processor itself
        is only constructed later, when it is requested for a model config.
        A warning is logged if the class already defines its own
        `_processor_factory`, which is then replaced.
        """

        def _register(model_cls: N) -> N:
            # Check the class's own namespace only, so that a factory
            # inherited from a base class does not trigger the warning.
            already_registered = "_processor_factory" in vars(model_cls)
            if already_registered:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            factories = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )
            model_cls._processor_factory = factories
            return model_cls

        return _register

243
    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        """
        Resolve the model class for `model_config` and check that processor
        factories have been registered on it.
        """
        # Imported here rather than at module level to avoid a circular import.
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _unused = get_model_architecture(model_config)

        has_factories = hasattr(model_cls, "_processor_factory")
        assert has_factories
        return cast("SupportsMultiModal", model_cls)
250

251
252
253
    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        """
        Build the input processing context for a model.

        When no tokenizer is supplied, the cached tokenizer for
        `model_config` is used.
        """
        resolved_tokenizer = (
            cached_tokenizer_from_config(model_config)
            if tokenizer is None
            else tokenizer
        )
        return InputProcessingContext(
            model_config,
            resolved_tokenizer,
            observability_config=observability_config,
        )
263

264
265
    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        """
        Build only the processing info for a model, using the `info` factory
        registered on its model class (no processor is constructed).
        """
        # Resolve the model class first so a missing registration surfaces
        # before any context is created.
        info_factory = self._get_model_cls(model_config)._processor_factory.info
        ctx = self._create_processing_ctx(
            model_config,
            observability_config,
            tokenizer,
        )
        return info_factory(ctx)

276
277
    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Build the multi-modal processor registered for the given model.

        Raises:
            ValueError: If `model_config` does not describe a multimodal
                model.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        factories = self._get_model_cls(model_config)._processor_factory
        ctx = self._create_processing_ctx(
            model_config,
            observability_config,
            tokenizer,
        )
        return factories.build_processor(ctx, cache=cache)
296

297
    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int] | None = None,
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        observability_config: ObservabilityConfig | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Produce dummy multi-modal inputs for profiling a model's memory usage.

        A processor is created from `model_config` unless one is passed in,
        and `mm_counts` defaults to that processor's allowed limits. The
        returned prompt token ids are padded in place with 0 so that they
        contain at least `model_config.max_model_len` tokens.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(
                model_config, observability_config, cache=cache
            )
        counts = processor.allowed_mm_limits if mm_counts is None else mm_counts

        dummy = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=counts,
            mm_options=self._extract_mm_options(model_config),
        )
        mm_inputs = processor.apply(
            prompt=dummy.prompt,
            mm_data=dummy.mm_data,
            hf_processor_mm_kwargs=dummy.hf_processor_mm_kwargs,
            tokenization_kwargs=dummy.tokenization_kwargs,
        )

        # Pad the token list held by `mm_inputs` itself (no copy is made).
        token_ids = mm_inputs["prompt_token_ids"]
        num_padding = seq_len - len(token_ids)
        if num_padding > 0:
            token_ids.extend([0] * num_padding)

        return mm_inputs
338

339
    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
        """
        Return the maximum encoder input length for an encoder-decoder model.

        The result is 0 when the model is not encoder-decoder or reports no
        per-item token maxima; otherwise it is the maximum token count of the
        model's single modality.
        """
        if not model_config.is_encoder_decoder:
            return 0

        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
        if len(max_tokens) == 0:
            # TODO - this function assumes encoder-decoder models are
            # multimodal. This will need to change when adding support for more
            # than whisper.
            return 0

        assert len(max_tokens) == 1, (
            "Encoder-decoder models are expected "
            "to implement the multimodal interface with at most one modality."
        )

        # Exactly one entry remains, so the first value is the answer.
        return next(iter(max_tokens.values()))