registry.py 17.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
import json
5
from collections import UserDict
6
from collections.abc import Mapping, Sequence
7
from dataclasses import dataclass
8
from typing import TYPE_CHECKING, Any, Generic, Optional, Protocol, TypeVar
9

10
11
import torch.nn as nn

12
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
13
from vllm.inputs import InputProcessingContext
14
from vllm.logger import init_logger
15
16
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
17
from vllm.utils import ClassRegistry
18

19
from .audio import AudioPlugin
20
from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc
21
from .image import ImagePlugin
22
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
23
24
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
                         ProcessingCache)
25
26
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
                        DummyEncoderData, MultiModalProfiler)
27
from .video import VideoPlugin
28

29
30
31
if TYPE_CHECKING:
    # Imported for type checking only; annotations below refer to it by
    # its string name.
    from vllm.config import ModelConfig

logger = init_logger(__name__)

# Type of a model class; used so the registration decorator returns the
# same class type it receives.
N = TypeVar("N", bound=type[nn.Module])

# Processing-info type variables: an invariant one and a covariant one.
_I = TypeVar("_I", bound=BaseProcessingInfo)
_I_co = TypeVar("_I_co", bound=BaseProcessingInfo, covariant=True)
37
38


39
class ProcessingInfoFactory(Protocol[_I_co]):
    """Constructs a :class:`BaseProcessingInfo` instance from the context."""

    def __call__(
        self,
        ctx: InputProcessingContext,
    ) -> _I_co:
        ...


class DummyInputsBuilderFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseDummyInputsBuilder` instance from the
    processing info.
    """

    def __call__(self, info: _I) -> BaseDummyInputsBuilder[_I]:
        ...


class MultiModalProcessorFactory(Protocol[_I]):
    """
    Constructs a :class:`BaseMultiModalProcessor` instance from the
    processing info and the dummy inputs builder.

    An optional processing cache may be supplied by keyword.
    """

    def __call__(
        self,
        info: _I,
        dummy_inputs: BaseDummyInputsBuilder[_I],
        *,
        cache: Optional[ProcessingCache] = None,
    ) -> BaseMultiModalProcessor[_I]:
        ...
69

70

71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
@dataclass(frozen=True)
class _ProcessorFactories(Generic[_I]):
    # Factory producing the processing info from the input context.
    info: ProcessingInfoFactory[_I]
    # Factory producing the processor from the info and dummy inputs builder.
    processor: MultiModalProcessorFactory[_I]
    # Factory producing the dummy inputs builder from the info.
    dummy_inputs: DummyInputsBuilderFactory[_I]

    def build_processor(
        self,
        ctx: InputProcessingContext,
        *,
        cache: Optional[ProcessingCache] = None,
    ):
        """
        Chain the three factories: build the info from ``ctx``, build the
        dummy inputs builder from that info, then build the processor from
        both, forwarding ``cache``.
        """
        processing_info = self.info(ctx)
        builder = self.dummy_inputs(processing_info)

        make_processor = self.processor
        return make_processor(processing_info, builder, cache=cache)


88
class _MultiModalLimits(UserDict["ModelConfig", dict[str, int]]):
89
90
91
92
93
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.
    """

94
    def __getitem__(self, key: "ModelConfig") -> dict[str, int]:
95
96
97
98
99
100
101
102
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
                   "forget to call `init_mm_limits_per_prompt`?")
            raise KeyError(msg) from exc


103
104
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

108
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())
109

110
    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Set up the registry state.

        Args:
            plugins: Plugins to register initially; each is indexed by the
                value of its ``get_data_key()``.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

        # Processor factories, keyed by model class.
        self._processor_factories = ClassRegistry[nn.Module,
                                                  _ProcessorFactories]()

        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

        self._limits_by_model = _MultiModalLimits()

        # Handed to processors unless caching is disabled for the call.
        self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
125

126
    def register_plugin(self, plugin: MultiModalPlugin) -> None:
127
128
129
        """
        Register a multi-modal plugin so it can be recognized by vLLM.
        """
130
        data_type_key = plugin.get_data_key()
131

132
        if data_type_key in self._plugins:
133
134
            logger.warning(
                "A plugin is already registered for data type %s, "
135
                "and will be overwritten by the new plugin %s.", data_type_key,
136
137
                plugin)

138
        self._plugins[data_type_key] = plugin
139

140
141
142
143
    def _get_plugin(self, data_type_key: str):
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin
144

145
        msg = f"Unknown multi-modal data type: {data_type_key}"
146
147
        raise NotImplementedError(msg)

148
    def register_input_mapper(
149
        self,
150
        data_type_key: str,
151
        mapper: Optional[MultiModalInputMapper] = None,
152
    ):
153
        """
154
        Register an input mapper for a specific modality to a model class.
155

156
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
157
        """
158
        return self._get_plugin(data_type_key).register_input_mapper(mapper)
159

160
    def register_image_input_mapper(
161
        self,
162
        mapper: Optional[MultiModalInputMapper] = None,
163
    ):
164
        """
165
        Register an input mapper for image data to a model class.
166

167
        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
168
        """
169
        return self.register_input_mapper("image", mapper)
170

171
172
    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            NotImplementedError: If a key of ``data`` has no plugin.
            KeyError: If no limits were initialized for ``model_config``.
            ValueError: If a modality has more items than its limit, or if
                two plugins produce the same keyword argument.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        merged_dict = dict[str, NestedTensors]()

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)

            # Only a list counts as multiple items; any other value is
            # treated as a single item.
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set '{json.dumps({data_key: max_items})}' (or "
                    "defaulted to 1) in `--limit-mm-per-prompt`, but found "
                    f"{num_items} items in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)

            # Merge the per-modality outputs, refusing duplicate keys.
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalKwargs(merged_dict)
213

214
    def create_input_mapper(self, model_config: "ModelConfig"):
215
        """
216
        Create an input mapper (see :meth:`map_input`) for a specific model.
217
        """
218
219
220
221
222
223
224
225
226
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
227
        return functools.partial(self.map_input, model_config)
228

229
230
231
232
233
    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
234
        """
235
236
237
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.
238
239
240
241
242
243
244
245
246
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
247
248
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
249
250
251
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

252
253
254
255
256
    def get_max_tokens_per_item_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
257
258
        Get the maximum number of tokens per data item from each modality based 
        on underlying model configuration.
259
        """
260
        if self.has_processor(model_config):
261
            processor = self.create_processor(model_config, disable_cache=True)
262
263
            profiler = MultiModalProfiler(processor)

264
            seq_len = model_config.max_model_len
265
            mm_limits = self.get_mm_limits_per_prompt(model_config)
266
267
268

            return profiler.get_mm_max_tokens(
                seq_len,
269
270
271
272
                {
                    modality: 1
                    for modality, limit in mm_limits.items() if limit > 0
                },
273
            )
274
275
276
277
278
279

        return {
            key: plugin.get_max_multimodal_tokens(model_config)
            for key, plugin in self._plugins.items()
        }

280
281
282
283
284
285
286
287
288
289
290
291
292
    def get_max_tokens_per_item_by_nonzero_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of tokens per data item from each modality based
        on underlying model configuration, excluding modalities that user 
        explicitly disabled via `limit_mm_per_prompt`.

        Note:
            This is currently directly used only in V1 for profiling the memory 
            usage of a model.
        """
293
        mm_limits = self.get_mm_limits_per_prompt(model_config)
294
295
296
297
298

        return {
            key: max_tokens_per_mm_item
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
299
            if mm_limits[key] > 0
300
301
        }

302
303
304
305
    def get_max_tokens_by_modality(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
306
        """
307
        Get the maximum number of tokens from each modality
308
        for profiling the memory usage of a model.
309

310
        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
311
312
313
314

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
315
        mm_limits = self.get_mm_limits_per_prompt(model_config)
316

317
        return {
318
            key: mm_limits[key] * max_tokens_per_mm_item
319
320
            for key, max_tokens_per_mm_item in
            self.get_max_tokens_per_item_by_modality(model_config).items()
321
322
323
324
325
326
327
328
329
330
331
332
333
        }

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        return sum(self.get_max_tokens_by_modality(model_config).values())
334
335
336

    def init_mm_limits_per_prompt(
        self,
337
        model_config: "ModelConfig",
338
339
340
341
342
343
344
345
346
347
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

348
        multimodal_config = model_config.multimodal_config
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
        if multimodal_config is None:
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
365
                key: multimodal_config.get_limit_per_prompt(key)
366
367
368
369
370
371
372
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
373
        model_config: "ModelConfig",
374
375
376
377
378
379
380
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
381
        """
382
        if self.has_processor(model_config):
383
            processor = self.create_processor(model_config, disable_cache=True)
384
385
386
            profiler = MultiModalProfiler(processor)
            return profiler.get_mm_limits()

387
        return self._limits_by_model[model_config]
388
389
390

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.

        Returns:
            A class decorator that stores the three factories for the
            decorated model class and returns the class unchanged. If the
            class already has factories registered, a warning is logged and
            they are replaced.

        See also:
            :ref:`mm-processing`
        """

        def wrapper(model_cls: N) -> N:
            if self._processor_factories.contains(model_cls, strict=True):
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._processor_factories[model_cls] = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

424
    def _get_model_cls(self, model_config: "ModelConfig"):
        """
        Resolve the model class for ``model_config`` using
        ``get_model_architecture`` and return it.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        # Only the first element of the returned pair is needed here.
        model_cls, _ = get_model_architecture(model_config)
        return model_cls

    def has_processor(self, model_config: "ModelConfig") -> bool:
        """
        Test whether a multi-modal processor is defined for a specific model.
434
435
436

        See also:
            :ref:`mm-processing`
437
438
        """
        return self._get_model_cls(model_config) in self._processor_factories
439
440
441
442

    def create_processor(
        self,
        model_config: "ModelConfig",
        *,
        tokenizer: Optional[AnyTokenizer] = None,
        disable_cache: Optional[bool] = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.

        Args:
            model_config: Configuration identifying the model.
            tokenizer: Tokenizer to use. If ``None``, it is obtained from
                ``cached_tokenizer_from_config(model_config)``.
            disable_cache: Whether to build the processor without the
                registry's processing cache. If ``None``, the value of
                ``model_config.disable_mm_preprocessor_cache`` is used.

        See also:
            :ref:`mm-processing`
        """
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            disable_cache = model_config.disable_mm_preprocessor_cache

        model_cls = self._get_model_cls(model_config)
        factories = self._processor_factories[model_cls]

        ctx = InputProcessingContext(model_config, tokenizer)
        cache = None if disable_cache else self._processing_cache

        return factories.build_processor(ctx, cache=cache)
465
466
467
468
469

    def get_decoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyDecoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``. The processor used for
        this is created with its cache disabled.

        Raises:
            AssertionError: If the dummy data has fewer than ``seq_len``
                prompt tokens.
        """
        processor = self.create_processor(model_config, disable_cache=True)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            raise AssertionError(
                f"Expected at least {seq_len} dummy tokens for profiling, "
                f"but found {len(token_ids)} tokens instead.")

        return dummy_data

    def get_encoder_dummy_data(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_counts: Optional[Mapping[str, int]] = None,
    ) -> DummyEncoderData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``. The processor used for
        this is created with its cache disabled. Unlike the decoder variant,
        having fewer than ``seq_len`` prompt tokens only logs a warning.
        """
        processor = self.create_processor(model_config, disable_cache=True)
        profiler = MultiModalProfiler(processor)
        dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)

        # Having more tokens is over-conservative but otherwise fine
        token_ids = dummy_data.prompt_token_ids
        if len(token_ids) < seq_len:
            logger.warning_once(
                f"Expected at least {seq_len} dummy encoder tokens for "
                f"profiling, but found {len(token_ids)} tokens instead.")

        return dummy_data