"tests/lora/test_whisper.py" did not exist on "2f4226fe5280b60c47b4f6f01d9b18ac9cda2038"
registry.py 8.96 KB
Newer Older
1
import functools
2
from collections import UserDict
3
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Sequence
4
5
6

from vllm.logger import init_logger

7
from .audio import AudioPlugin
8
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalKwargs,
9
                   MultiModalPlugin, MultiModalTokensCalc, NestedTensors)
10
from .image import ImagePlugin
11
from .video import VideoPlugin
12

13
14
15
if TYPE_CHECKING:
    from vllm.config import ModelConfig

16
17
18
# Module-level logger; used below to warn about overwritten plugins,
# re-initialized `mm_limits`, and unrecognized `--limit-mm-per-prompt` keys.
logger = init_logger(__name__)


19
20
21
22
23
24
class _MultiModalLimits(UserDict):
    """
    Wraps `_limits_by_model` for a more informative error message
    when attempting to access a model that does not exist.

    Keys are model configs and values are mappings from a modality key
    to the maximum number of items of that modality allowed per prompt.
    """

    def __getitem__(self, key: "ModelConfig") -> Dict[str, int]:
        """
        Return the per-modality limits stored for ``key``.

        Raises:
            KeyError: If no limits have been stored for ``key``. The message
                names ``key.model`` and points at `init_mm_limits_per_prompt`;
                the original lookup error is chained as the cause.
        """
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"Cannot find `mm_limits` for model={key.model}. Did you "
                   "forget to call `init_mm_limits_per_prompt`?")
            raise KeyError(msg) from exc


34
35
class MultiModalRegistry:
    """
    A registry that dispatches data processing to the
    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.

    Plugins are stored by the data key they report via ``get_data_key()``.
    Per-model limits on the number of multi-modal items per prompt are
    stored separately and must be initialized with
    :meth:`init_mm_limits_per_prompt` before they are read.
    """

    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin())

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Create a registry holding the given plugins.

        Args:
            plugins: The plugins to register initially, keyed internally by
                each plugin's ``get_data_key()``.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

        # This is used for non-multimodal models
        self._disabled_limits_per_plugin = {k: 0 for k in self._plugins}

        self._limits_by_model = _MultiModalLimits()

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        If a plugin is already registered under the same data key, a warning
        is logged and the existing plugin is replaced.

        Note:
            Limits already stored by :meth:`init_mm_limits_per_prompt` for a
            model with a multi-modal config are built from the plugins
            registered at that time and are not updated here.

        See also:
            :ref:`adding_multimodal_plugin`
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin

        # Keep the limits used for non-multimodal models in sync with the
        # registered plugins; otherwise `get_max_multimodal_tokens` would
        # raise `KeyError` for a plugin that was added after construction.
        self._disabled_limits_per_plugin.setdefault(data_type_key, 0)

    def _get_plugin(self, data_type_key: str) -> MultiModalPlugin:
        """
        Return the plugin registered for ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin is registered for the key.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        Raises:
            NotImplementedError: If no plugin is registered for
                ``data_type_key``.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(
        self,
        model_config: "ModelConfig",
        data: MultiModalDataDict,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> MultiModalKwargs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.

        Raises:
            NotImplementedError: If a key in ``data`` has no registered
                plugin.
            KeyError: If limits have not been initialized for
                ``model_config``.
            ValueError: If a modality has more items than its limit, or if
                two plugins produce the same keyword argument.
        """
        merged_dict: Dict[str, NestedTensors] = {}

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)

            # A non-list value counts as a single item of that modality.
            num_items = len(data_value) if isinstance(data_value, list) else 1
            max_items = self._limits_by_model[model_config][data_key]
            if num_items > max_items:
                raise ValueError(
                    f"You set {data_key}={max_items} (or defaulted to 1) in "
                    f"`--limit-mm-per-prompt`, but found {num_items} items "
                    "in the same prompt.")

            input_dict = plugin.map_input(model_config, data_value,
                                          mm_processor_kwargs)

            # Each keyword argument may be produced by one plugin only.
            for input_key, input_tensor in input_dict.items():
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalKwargs(merged_dict)

    def create_input_mapper(self, model_config: "ModelConfig"):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        The result is :meth:`map_input` with ``model_config`` already bound.
        """
        # NOTE - we currently make the assumption that if a model has multiple
        # supported modalities, they take the same kwargs. For the default,
        # this could be an issue in the future if it falls back to two HF
        # resources and we can't inspect the signature easily since it's
        # getting initialized through the autoclass.
        #
        # If this is a problem in the future, we should revisit it, but since
        # it potentially introduces a lot of complexity for a currently
        # uncommon case, we do not for simplicity of both use & implementation
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data belonging to a specific modality, that are
        passed to the language model for a model class.

        Raises:
            NotImplementedError: If no plugin is registered for
                ``data_type_key``.
        """
        return self._get_plugin(data_type_key) \
            .register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens, corresponding to a single
        image, that are passed to the language model for a model class.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        The result is the sum, over all registered plugins, of the model's
        per-prompt limit for that plugin multiplied by the plugin's own
        maximum token count for the model.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.
        """
        limits_per_plugin = self._limits_by_model[model_config]

        return sum((limits_per_plugin[key] *
                    plugin.get_max_multimodal_tokens(model_config))
                   for key, plugin in self._plugins.items())

    def init_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> None:
        """
        Initialize the maximum number of multi-modal input instances for each
        modality that are allowed per prompt for a model class.

        If ``model_config.multimodal_config`` is ``None``, every registered
        plugin gets a limit of 0. Otherwise the limit for each plugin is
        taken from ``multimodal_config.limit_per_prompt``, defaulting to 1.
        Calling this again for the same model logs a warning and overwrites
        the stored limits.
        """
        if model_config in self._limits_by_model:
            logger.warning(
                "`mm_limits` has already been set for model=%s, and will "
                "be overwritten by the new values.", model_config.model)

        multimodal_config = model_config.multimodal_config
        if multimodal_config is None:
            limits_per_plugin = self._disabled_limits_per_plugin
        else:
            config_limits_per_plugin = multimodal_config.limit_per_prompt

            extra_keys = config_limits_per_plugin.keys() - self._plugins.keys()
            if extra_keys:
                logger.warning(
                    "Detected extra keys in `--limit-mm-per-prompt` which "
                    "are not registered as multi-modal plugins: %s. "
                    "They will be ignored.", extra_keys)

            # NOTE: Currently the default is set to 1 for each plugin
            # TODO: Automatically determine the limits based on budget
            # once more models support multi-image inputs
            limits_per_plugin = {
                key: config_limits_per_plugin.get(key, 1)
                for key in self._plugins
            }

        self._limits_by_model[model_config] = limits_per_plugin

    def get_mm_limits_per_prompt(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, int]:
        """
        Get the maximum number of multi-modal input instances for each modality
        that are allowed per prompt for a model class.

        Note:
            This should be called after :meth:`init_mm_limits_per_prompt`.

        Raises:
            KeyError: If limits have not been initialized for
                ``model_config``.
        """
        return self._limits_by_model[model_config]