registry.py 12.1 KB
Newer Older
1
import functools
2
from collections import UserDict
3
4
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
                    Sequence, Type, TypeVar)
5

6
7
8
9
import torch.nn as nn
from typing_extensions import TypeAlias

from vllm.inputs import InputProcessingContext
10
from vllm.logger import init_logger
11
from vllm.transformers_utils.tokenizer import AnyTokenizer
12
from vllm.utils import ClassRegistry
13

14
from .audio import AudioPlugin
15
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
16
from .image import ImagePlugin
17
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
18
from .processing import BaseMultiModalProcessor
19
from .video import VideoPlugin
20

21
22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig

24
25
# Module-level logger shared by everything defined in this file.
logger = init_logger(__name__)

# Type variable bound to model classes; it lets a class decorator declare
# that it returns the very same class it was given.
N = TypeVar("N", bound=Type[nn.Module])

# Signature of a factory that builds a processor from an input processing
# context.
MultiModalProcessorFactory: TypeAlias = Callable[
    [InputProcessingContext],
    BaseMultiModalProcessor,
]
"""
Constructs a :class:`MultiModalProcessor` instance from the context.

The processing metadata should be derived from the context.
"""

36

37
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
38
39
40
41
42
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.
    """

43
    def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
44
45
46
47
48
49
50
51
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
                   "forget to call `init_mm_limits_per_prompt`?")
            raise KeyError(msg) from exc


52
53
class MultiModalRegistry:
    """
    A registry that dispatches data processing to the
    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.

    Plugins are keyed by the value returned from their ``get_data_key()``
    method. Besides the plugins, the registry keeps the per-prompt item
    limits of each model (see :meth:`init_mm_limits_per_prompt`) and the
    multi-modal processor factory registered to each model class
    (see :meth:`register_processor`).
    """

    # NOTE: These instances are created once, when the class is defined, and
    # are the default for every registry that is not given its own plugins.
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Create a registry that knows about the given plugins.

        Args:
            plugins: The plugins to register initially, one per modality.
                If several plugins share a data key, the last one wins.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

        self._processor_factories = ClassRegistry[nn.Module,
                                                  MultiModalProcessorFactory]()

        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

        self._limits_by_model = _MultiModalLimits()

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        If a plugin is already registered for the same data type key,
        a warning is logged and the existing plugin is replaced.

        See also:
            :ref:`adding_multimodal_plugin`
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin

        # Keep the limits used for non-multimodal models in sync with the
        # registered plugins. Without this, a modality added after
        # construction has no entry there, and looking up its limit for
        # a non-multimodal model fails.
        self._disabled_limits_per_plugin.setdefault(data_type_key, 0)

    def _get_plugin(self, data_type_key: str) -> MultiModalPlugin:
        """
        Return the plugin registered for ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin is registered for the key.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        The call is forwarded to the plugin registered for
        ``data_type_key``, and that plugin's return value is returned
        unchanged.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        Shorthand for :meth:`register_input_mapper` with the ``"image"``
        data type key.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.

        Raises:
            NotImplementedError: If ``data`` contains a modality for which
                no plugin is registered.
            ValueError: If a modality has more items than the limit set for
                the model, or if two input mappers produce the same keyword
                argument.
        """
        merged_dict: Dict[str, NestedTensors] = {}

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)

            # Anything that is not a list counts as a single item.
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)

            # The outputs of all modalities share one keyword namespace,
            # so a key produced twice cannot be merged and is rejected.
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalKwargs(merged_dict)

    def create_input_mapper(self, model_config: "ModelConfig"):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        Returns:
            A partial application of :meth:`map_input` with ``model_config``
            already bound.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.

        The call is forwarded to the plugin registered for
        ``data_type_key``, and that plugin's return value is returned
        unchanged.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.

        Shorthand for :meth:`register_max_multimodal_tokens` with the
        ``"image"`` data type key.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens from each modality
        for profiling the memory usage of a model.

        For every registered plugin, the value is the model's per-prompt
        item limit for that modality multiplied by the plugin's maximum
        number of tokens.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        limits_per_plugin = self._limits_by_model[model_config]

        return {
            key: (limits_per_plugin[key] *
                  plugin.get_max_multimodal_tokens(model_config))
            for key, plugin in self._plugins.items()
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        This is the sum of the values returned by
        :meth:`get_max_tokens_by_modality`.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())

    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.

        A model without a multi-modal config gets a limit of 0 for every
        registered modality. Otherwise the limits are read from the config's
        ``limit_per_prompt``, with a default of 1 for each registered
        modality that is not listed there. Calling this again for the same
        model logs a warning and overwrites the stored limits.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

        multimodal_config = model_config.multimodal_config
        if multimodal_config is None:
            # Store a copy so that every model owns its limits; sharing one
            # dict would let a change made through one model's entry leak
            # into every other non-multimodal model.
            limits_per_plugin = dict(self._disabled_limits_per_plugin)
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.

        Raises:
            KeyError: If no limits have been initialized for the model.
        """
        return self._limits_by_model[model_config]

    def register_processor(
        self,
        factory: MultiModalProcessorFactory,
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns:
            A class decorator that records ``factory`` for the decorated
            model class and returns the class unchanged.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._processor_factories:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = factory

            return model_cls

        return wrapper

    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        return model_cls in self._processor_factories

    def create_processor(
        self,
        model_config: "ModelConfig",
        tokenizer: AnyTokenizer,
    ) -> BaseMultiModalProcessor:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        The factory registered for the model's class is called with an
        :class:`InputProcessingContext` built from ``model_config`` and
        ``tokenizer``.

        Note:
            The factory lookup presumably fails when no processor has been
            registered for the model class; check :meth:`has_processor`
            first when that is possible.
        """

        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor_factory = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        return processor_factory(ctx)