registry.py 16.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from collections import UserDict
5
6
7
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
                    Protocol, Sequence, Type, TypeVar)
8

9
10
import torch.nn as nn

11
from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
12
from vllm.inputs import InputProcessingContext
13
from vllm.logger import init_logger
14
15
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
16
from vllm.utils import ClassRegistry
17

18
from .audio import AudioPlugin
19
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
20
from .image import ImagePlugin
21
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
22
23
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         ProcessingCache)
24
from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
25
from .video import VideoPlugin
26

27
28
29
if TYPE_CHECKING:
    from vllm.config import ModelConfig

# Type variable for decorators that receive a model class and hand it back
# unchanged (see the `wrapper` returned by
# `MultiModalRegistry.register_processor`).
N = TypeVar("N", bound=Type[nn.Module])

# Type variables over processing-info classes. `_I` is invariant and is used
# where the info type appears as an argument (e.g. the dummy-inputs and
# processor factories); `_I_co` is covariant, so it may only appear in
# return position, as in `ProcessingInfoFactory.__call__`.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)

logger = init_logger(__name__)


class ProcessingInfoFactory(Protocol[_I_co]):
    """Constructs a :class:`BaseProcessingInfo` instance from the context."""

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...


class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Factory protocol that, given the processing info of a model, builds the
    :class:`BaseDummyInputsBuilder` paired with that info type.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        """Return a :class:`BaseDummyInputsBuilder` for ``info``."""


class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseMultiModalProcessor` instance from the
    processing info and the dummy inputs builder built for that info.

    ``cache`` is an optional :class:`ProcessingCache` to be used by the
    processor; it is passed as a keyword-only argument.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...

@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """
    Immutable bundle of the three factories that together produce a
    multi-modal processor for one model class.
    """

    # Builds the processing info from an `InputProcessingContext`.
    info: ProcessingInfoFactory[_I]
    # Builds the processor from the info and the dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Builds the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """
        Instantiate the processor for ``ctx``.

        The info object is created first and shared by the dummy inputs
        factory and the processor factory; ``cache`` is forwarded to the
        processor factory unchanged.
        """
        processing_info = self.info(ctx)
        return self.processor(
            processing_info,
            self.dummy_inputs(processing_info),
            cache=cache,
        )


86
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
87
88
89
90
91
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.
    """

92
    def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
93
94
95
96
97
98
99
100
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
                   "forget to call `init_mm_limits_per_prompt`?")
            raise KeyError(msg) from exc


class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    # NOTE(review): these plugin instances are created once, when the class
    # body runs, so every registry constructed with the default ``plugins``
    # argument holds the very same instances. If a plugin keeps registration
    # state on itself (e.g. from `register_input_mapper`), that state is
    # shared between such registries -- confirm this is intended.
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        # Data key (e.g. "image") -> plugin handling that modality.
        self._plugins = {p.get_data_key(): p for p in plugins}

        # Model class -> factories used to build its processor on demand.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # All-zero limits for non-multimodal models; `register_plugin` keeps
        # its keys in sync with `self._plugins`.
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

        # Model config -> {data key: max items allowed per prompt}.
        self._limits_by_model = _MultiModalLimits()

        # Shared by every processor created with caching enabled
        # (see `create_processor`).
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_SIZE)

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        A plugin already registered for the same data key is replaced,
        and a warning is logged.
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin
        # Keep the disabled limits covering every registered plugin, so that
        # per-modality lookups for non-multimodal models do not raise
        # KeyError for this data key.
        self._disabled_limits_per_plugin[data_type_key] = 0

    def _get_plugin(self, data_type_key: str):
        """
        Return the plugin registered for ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin handles that data type.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            KeyError: If limits were never initialized for ``model_config``.
            NotImplementedError: If ``data`` contains an unknown modality.
            ValueError: If a modality has more items than its per-prompt
                limit, or if two mappers produce the same keyword argument.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        merged_dict: Dict[str, NestedTensors] = {}

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)

            # A list holds several items of this modality; any other value
            # counts as a single item.
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)
            for input_key, input_tensor in input_dict.items():
                # Mappers of different modalities must not produce the same
                # `forward()` keyword argument.
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalKwargs(merged_dict)

    def create_input_mapper(self, model_config: "ModelConfig"):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration.

        Models with a registered processor are queried through that
        processor; otherwise every registered plugin is asked directly.
        """
        if self.has_processor(model_config):
            tokenizer = cached_tokenizer_from_config(model_config)
            processor = self.create_processor(model_config,
                                              tokenizer,
                                              disable_cache=True)
            seq_len = model_config.max_model_len
            mm_limits = self.get_mm_limits_per_prompt(model_config)
            return processor.info.get_mm_max_tokens_per_item(
                seq_len, mm_limits)

        return {
            key: plugin.get_max_multimodal_tokens(model_config)
            for key, plugin in self._plugins.items()
        }

    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory
            usage of a model.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
            if mm_limits[key] > 0
        }

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.

        Each value is the per-item maximum multiplied by the number of items
        of that modality allowed per prompt.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        mm_limits = self.get_mm_limits_per_prompt(model_config)

        return {
            key: mm_limits[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.

        Non-multimodal models get a limit of 0 for every registered plugin.
        Multimodal models use `limit_per_prompt` from their multimodal
        config, defaulting to 1 for plugins that are not mentioned there.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

        multimodal_config = model_config.multimodal_config
        if multimodal_config is None:
            # Disable every currently registered modality. A fresh dict is
            # built from `self._plugins` (like the branch below) so the stored
            # limits include plugins registered after `__init__` and are not
            # a single dict object shared between models.
            limits_per_plugin = {key: 0 for key in self._plugins}
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Models with a registered processor get their limits from a
        :class:`MultiModalProfiler` built around that processor; other models
        use the limits stored by :meth:`init_mm_limits_per_prompt`.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        if self.has_processor(model_config):
            tokenizer = cached_tokenizer_from_config(model_config)
            processor = self.create_processor(model_config,
                                              tokenizer,
                                              disable_cache=True)
            profiler = MultiModalProfiler(processor)
            return profiler.get_mm_limits()

        return self._limits_by_model[model_config]

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a decorator that records the factories for the decorated
        model class and returns that class unchanged.

        See also:
            :ref:`mm-processing`
        """

        def wrapper(model_cls: N) -> N:
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class for ``model_config``."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.

        See also:
            :ref:`mm-processing`
        """
        return self._get_model_cls(model_config) in self._processor_factories

    def create_processor(
        self,
        model_config: "ModelConfig",
        tokenizer: AnyTokenizer,
        *,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        If ``disable_cache`` is None, it falls back to
        ``model_config.disable_mm_preprocessor_cache``. When caching is
        enabled, the registry's shared processing cache is used.

        See also:
            :ref:`mm-processing`
        """
        if disable_cache is None:
            disable_cache = model_config.disable_mm_preprocessor_cache

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache

        return factories.build_processor(ctx, cache=cache)