# registry.py (14.7 KB)
1
import functools
2
from collections import UserDict
3
4
5
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, Mapping, Optional,
                    Protocol, Sequence, Type, TypeVar)
6

7
8
9
import torch.nn as nn

from vllm.inputs import InputProcessingContext
10
from vllm.logger import init_logger
11
from vllm.transformers_utils.tokenizer import AnyTokenizer
12
from vllm.utils import ClassRegistry
13

14
from .audio import AudioPlugin
15
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
16
from .image import ImagePlugin
17
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
18
19
20
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         ProcessingCache)
from .profiling import BaseDummyInputsBuilder
21
from .utils import cached_get_tokenizer
22
from .video import VideoPlugin
23

24
25
26
if TYPE_CHECKING:
    from vllm.config import ModelConfig

27
28
logger = init_logger(__name__)

# Capacity passed to the shared `ProcessingCache` that
# `MultiModalRegistry.__init__` creates.
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256

# A model class (subclass of `nn.Module`); used so the decorator returned by
# `MultiModalRegistry.register_processor` returns the class it was given.
N = TypeVar("N", bound=Type[nn.Module])

# The processing-info type threaded through the factory protocols below.
_I = TypeVar("_I", bound=BaseProcessingInfo)
# Covariant variant for protocols that only *return* the processing info
# (type checkers require return-only protocol type variables to be covariant).
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)


class ProcessingInfoFactory(Protocol[_I_co]):
    """
    Constructs a :class:`BaseProcessingInfo` instance from the context.

    This is the first factory invoked by ``_ProcessorFactories.build_processor``;
    its result is then handed to the dummy-inputs and processor factories.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...


class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseDummyInputsBuilder` instance from the
    processing info (the object produced by a ``ProcessingInfoFactory``).
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...


class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseMultiModalProcessor` instance from the
    processing info and its matching dummy-inputs builder.

    ``cache`` is keyword-only; ``None`` means the processor runs without the
    shared processing cache.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...


@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    """Immutable bundle of the three factories registered for a model class."""

    info: ProcessingInfoFactory[_I]
    processor: MultiModalProcessorFactory[_I]
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """
        Chain the factories: processing info from ``ctx``, then the
        dummy-inputs builder from that info, then the processor from both
        (forwarding ``cache`` unchanged).
        """
        processing_info = self.info(ctx)
        return self.processor(
            processing_info,
            self.dummy_inputs(processing_info),
            cache=cache,
        )


86
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
87
88
89
90
91
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.
    """

92
    def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
93
94
95
96
97
98
99
100
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
                   "forget to call `init_mm_limits_per_prompt`?")
            raise KeyError(msg) from exc


101
102
class MultiModalRegistry:
    """
    A registry that dispatches data processing to the
    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
    """

    # Created once at class definition; every registry built with the default
    # argument shares these plugin instances.
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Args:
            plugins: Initial plugins, indexed by their ``get_data_key()``.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

        self._limits_by_model = _MultiModalLimits()

        self._processing_cache = ProcessingCache(MM_CACHE_SIZE)

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        See also:
            :ref:`adding-multimodal-plugin`
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin
        # Keep the limits used for non-multimodal models in sync with the
        # registered plugins. Without this, `get_max_tokens_by_modality`
        # raises KeyError for this modality on non-multimodal models.
        self._disabled_limits_per_plugin[data_type_key] = 0

    def _get_plugin(self, data_type_key: str):
        """
        Return the plugin registered for ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin handles that data type.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.

        Raises:
            NotImplementedError: If ``data`` contains an unknown modality.
            KeyError: If limits were never initialized for ``model_config``.
            ValueError: If a modality has more items than its per-prompt
                limit, or if two mappers produce the same keyword argument.
        """
        merged_dict: Dict[str, NestedTensors] = {}

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)

            # A list holds several items of this modality; any other value
            # counts as a single item.
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalKwargs(merged_dict)

    def create_input_mapper(self, model_config: "ModelConfig"):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
        """
        plugin = self._get_plugin(data_type_key)
        return plugin.register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality
        for profiling the memory usage of a model.

        Models with a registered processor are queried through their
        processing info (bounded by ``max_model_len``); otherwise each
        plugin is asked directly.

        Note:
            This is currently directly used only in V1.
        """
        if self.has_processor(model_config):
            tokenizer = cached_get_tokenizer(model_config.tokenizer)
            processor = self.create_processor(model_config, tokenizer)
            seq_len = model_config.max_model_len
            return processor.info.get_mm_max_tokens_per_item(seq_len)

        return {
            key: plugin.get_max_multimodal_tokens(model_config)
            for key, plugin in self._plugins.items()
        }

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.

        Each value is the per-item maximum multiplied by the per-prompt
        item limit for that modality.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        limits_per_plugin = self._limits_by_model[model_config]

        return {
            key: limits_per_plugin[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

        multimodal_config = model_config.multimodal_config
        if multimodal_config is None:
            # Non-multimodal models share this all-zero dict by reference.
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return self._limits_by_model[model_config]

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """

        def wrapper(model_cls: N) -> N:
            # `strict=True`: only warn when this exact class (not a base
            # class) already has factories registered.
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig"):
        """Resolve the model class for ``model_config``."""
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.
        """
        return self._get_model_cls(model_config) in self._processor_factories

    def create_processor(
        self,
        model_config: "ModelConfig",
        tokenizer: AnyTokenizer,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        The registry-wide processing cache is passed to the processor unless
        ``model_config.disable_mm_preprocessor_cache`` is set.
        """
        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = (None if model_config.disable_mm_preprocessor_cache else
                 self._processing_cache)

        return factories.build_processor(ctx, cache=cache)