registry.py 13.2 KB
Newer Older
1
import functools
2
from collections import UserDict
3
from typing import (TYPE_CHECKING, Any, Dict, Mapping, Optional, Protocol,
4
                    Sequence, Type, TypeVar)
5

6
7
8
import torch.nn as nn

from vllm.inputs import InputProcessingContext
9
from vllm.logger import init_logger
10
from vllm.transformers_utils.tokenizer import AnyTokenizer
11
from vllm.utils import ClassRegistry
12

13
from .audio import AudioPlugin
14
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
15
from .image import ImagePlugin
16
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
17
from .processing import BaseMultiModalProcessor, ProcessingCache
18
from .utils import cached_get_tokenizer
19
from .video import VideoPlugin
20

21
22
23
if TYPE_CHECKING:
    from vllm.config import ModelConfig

24
25
logger = init_logger(__name__)

26
27
28
# TODO: Tune the MM cache size
MM_CACHE_SIZE = 256

29
30
31
N = TypeVar("N", bound=Type[nn.Module])


32
33
34
35
36
37
38
39
40
41
class MultiModalProcessorFactory(Protocol):
    """
    Callable protocol for factories that construct a
    :class:`BaseMultiModalProcessor` instance from the context.

    A factory is called with the input processing context as its only
    positional argument, plus an optional keyword-only ``cache`` that
    defaults to ``None``.
    """

    def __call__(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor:
        # Protocol stub: only the call signature is specified here.
        ...
42

43

44
class _MultiModalLimits(UserDict["ModelConfig", Dict[str, int]]):
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.

    Apart from the lookup error message, this behaves like a plain
    :class:`~collections.UserDict`.
    """

    def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
        try:
            limits = super().__getitem__(key)
        except KeyError as lookup_error:
            # Re-raise with a hint for the caller; the original lookup
            # failure is kept as the cause of the new error.
            raise KeyError(
                f"Cannot find `mm_limits` for model={key.model}. Did you "
                "forget to call `init_mm_limits_per_prompt`?"
            ) from lookup_error

        return limits


59
60
class MultiModalRegistry:
    """
61
62
    A registry that dispatches data processing to the
    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.
63
64
    """

65
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
66

67
    def __init__(
        self,
        *,
        plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS,
    ) -> None:
        """
        Set up the registry state.

        Args:
            plugins: The plugins to start with. Each one is stored under the
                key returned by its ``get_data_key()``.
        """
        plugin_by_key = {}
        for plugin in plugins:
            plugin_by_key[plugin.get_data_key()] = plugin
        self._plugins = plugin_by_key

        # Processor factories, looked up by model class
        self._processor_factories = ClassRegistry[
            nn.Module, MultiModalProcessorFactory]()

        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = dict.fromkeys(plugin_by_key, 0)

        self._limits_by_model = _MultiModalLimits()

        self._processing_cache = ProcessingCache(MM_CACHE_SIZE)

83
    def register_plugin(self, plugin: MultiModalPlugin) -> None:
84
85
86
87
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        See also:
88
            :ref:`adding-multimodal-plugin`
89
        """
90
        data_type_key = plugin.get_data_key()
91

92
        if data_type_key in self._plugins:
93
94
            logger.warning(
                "A plugin is already registered for data type %s, "
95
                "and will be overwritten by the new plugin %s.", data_type_key,
96
97
                plugin)

98
        self._plugins[data_type_key] = plugin
99

100
101
102
103
    def _get_plugin(self, data_type_key: str):
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin
104

105
        msg = f"Unknown multi-modal data type: {data_type_key}"
106
107
        raise NotImplementedError(msg)

108
    def register_input_mapper(
109
        self,
110
        data_type_key: str,
111
        mapper: Optional[MultiModalInputMapper] = None,
112
    ):
113
        """
114
        Register an input mapper for a specific modality to a model class.
115

116
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
117
        """
118
        return self._get_plugin(data_type_key).register_input_mapper(mapper)
119

120
    def register_image_input_mapper(
121
        self,
122
        mapper: Optional[MultiModalInputMapper] = None,
123
    ):
124
        """
125
        Register an input mapper for image data to a model class.
126

127
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
128
        """
129
        return self.register_input_mapper("image", mapper)
130

131
132
    def map_input(
        self,
133
        model_config: "ModelConfig",
134
135
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
136
    ) -> MultiModalKwargs:
137
        """
138
        Apply an input mapper to the data passed to the model.
139
140
141
142
143

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into into keyword arguments
        via the input mapper registered for that model.

144
        See :meth:`MultiModalPlugin.map_input` for more details.
145
146
147

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
148
        """
149
        merged_dict: Dict[str, NestedTensors] = {}
150
151

        for data_key, data_value in data.items():
152
            plugin = self._get_plugin(data_key)
153

154
155
156
157
158
159
160
161
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

162
163
            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)
164
165
166
167
168
169
170
171
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

172
        return MultiModalKwargs(merged_dict)
173

174
    def create_input_mapper(self, model_config: "ModelConfig"):
175
        """
176
        Create an input mapper (see :meth:`map_input`) for a specific model.
177
        """
178
179
180
181
182
183
184
185
186
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
187
        return functools.partial(self.map_input, model_config)
188

189
190
191
192
193
    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
194
        """
195
196
197
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
198
199
200
201
202
203
204
205
206
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
207
208
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
209
210
211
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

212
213
214
215
216
217
218
219
220
221
222
    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality
        for profiling the memory usage of a model.

        Note:
            This is currently directly used only in V1.
        """
223
224
225
        if self.has_processor(model_config):
            tokenizer = cached_get_tokenizer(model_config.tokenizer)
            processor = self.create_processor(model_config, tokenizer)
226
227
            seq_len = model_config.max_model_len
            return processor.get_mm_max_tokens_per_item(seq_len)
228
229
230
231
232
233

        return {
            key: plugin.get_max_multimodal_tokens(model_config)
            for key, plugin in self._plugins.items()
        }

234
235
236
237
    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
238
        """
239
        Get the maximum number of tokens from each modality
240
        for profiling the memory usage of a model.
241

242
        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
243
244
245
246
247
248

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        limits_per_plugin = self._limits_by_model[model_config]

249
        return {
250
251
252
            key: limits_per_plugin[key] * max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
253
254
255
256
257
258
259
260
261
262
263
264
265
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())
266
267
268

    def init_mm_limits_per_prompt(
        self,
269
        model_config: "ModelConfig",
270
271
272
273
274
275
276
277
278
279
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

280
        multimodal_config = model_config.multimodal_config
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
        if multimodal_config is None:
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
305
        model_config: "ModelConfig",
306
307
308
309
310
311
312
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
313
        """
314
        return self._limits_by_model[model_config]
315
316
317
318
319
320

    def register_processor(
        self,
        factory: MultiModalProcessorFactory,
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns a class decorator: applying it to a model class stores
        ``factory`` for that class and hands the class back unchanged.

        See also:
            - :ref:`input-processing-pipeline`
            - :ref:`enabling-multimodal-inputs`
        """
        factories = self._processor_factories

        def wrapper(model_cls: N) -> N:
            already_registered = factories.contains(model_cls, strict=True)
            if already_registered:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            factories[model_cls] = factory

            return model_cls

        return wrapper

345
    def _get_model_cls(self, model_config: "ModelConfig"):
        """
        Return the model class that ``get_model_architecture`` resolves for
        ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        # Only the first element of the returned pair is needed here.
        resolved_cls, _unused = get_model_architecture(model_config)
        return resolved_cls

    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.
        """
        return self._get_model_cls(model_config) in self._processor_factories
357
358
359
360
361

    def create_processor(
        self,
        model_config: "ModelConfig",
        tokenizer: AnyTokenizer,
    ) -> BaseMultiModalProcessor:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        The factory registered for the model's class is called with a new
        input processing context. The registry's processing cache is passed
        along unless ``model_config.disable_mm_preprocessor_cache`` is set.
        """
        factory = self._processor_factories[self._get_model_cls(model_config)]

        ctx = InputProcessingContext(model_config, tokenizer)

        if model_config.disable_mm_preprocessor_cache:
            cache = None
        else:
            cache = self._processing_cache

        return factory(ctx, cache=cache)