base.py 11.7 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)
6

7
import numpy as np
8
9
10
11
import torch
import torch.types
from PIL import Image
from torch import nn
12
from typing_extensions import TypeAlias
13

14
15
from vllm.config import ModelConfig
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
18
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
                        json_map_leaves)
19
20
21

logger = init_logger(__name__)

22
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
23
"""
24
Uses a list instead of a tensor if the dimensions of each element do not match.
25
26
"""

27
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
28
29
30
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
31
32
33
34
35
36
37
38
"""

# Pick the base class for ``MultiModalInputs`` according to the running
# interpreter: the parameterized form is only used on Python 3.9 and newer.
if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        A tensor is returned unchanged. Otherwise each element is processed
        recursively, and the processed elements are merged into one stacked
        tensor only if all of them are tensors sharing a single shape.
        In every other case the list of processed elements is returned.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]

        if not stacked:
            # ``torch.stack`` rejects an empty sequence, so an empty list is
            # handed back as-is instead of raising.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        first_shape = tensors_[0].shape
        if any(t.shape != first_shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary contains every key that appears in at least
        one of the inputs. For each key, the values are collected in input
        order (inputs lacking the key contribute nothing). If the collected
        values are tensors that all share the same shape, the output value is
        a single batched tensor; otherwise, the output value is a list
        containing the original value from each input.

        An empty ``inputs_list`` yields an empty dictionary.
        """
        if not inputs_list:
            return {}

        # For models that supports multiple modalities (e.g. Qwen2-VL),
        # different modalities will return different data keys,
        # so batch() should skip the same key check.
        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)
        for inputs in inputs_list:
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Produce a copy of ``batched_inputs`` targeted at ``device``.

        The nested structure is handed to ``json_map_leaves`` together with
        a function that calls ``x.to(device, non_blocking=True)`` on the
        value it receives; the mapped structure is returned.
        """
        # The casts only bridge the two type aliases for the type checker.
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
112
113


114
115
116
117
118
119
120
121
122
123
124
125
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
134

135

136
137
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
138
139
"""
A dictionary containing an item for each modality type to input.
140

141
142
143
144
145
Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
146
"""
147

148
149
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
150
151
"""
Return a dictionary to be passed as keyword arguments to
152
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
153
154
155
156
157
158
159
160
161
162
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
163

164
165
N = TypeVar("N", bound=Type[nn.Module])

166

167
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registries, filled by the ``register_*`` decorators.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated
            model class and returns that class unchanged.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Test against ``None`` explicitly so that a callable object
            # which happens to be falsy is not silently replaced.
            if mapper is None:
                self._input_mappers[model_cls] = self._default_input_mapper
            else:
                self._input_mappers[model_cls] = mapper

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        # Only get processor kwargs at mapping time if we are not using the
        # default input mapper; no overrides are used on the default here
        # because they should be passed to the huggingface resource at
        # initialization time.
        # NOTE: ``!=`` is intentional. Every access to a bound method builds
        # a new method object, so an identity check would never match.
        if mapper != self._default_input_mapper:
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                mapper, overrides=model_config.mm_processor_kwargs)
        else:
            mm_processor_kwargs = {}

        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """
        Check that ``max_mm_tokens`` is at least 1.

        Raises:
            ValueError: If ``max_mm_tokens`` is smaller than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or calculation) for the
            decorated model class and returns that class unchanged. An
            integer value is validated when the decorator is applied, which
            raises :exc:`ValueError` if it is smaller than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Fall back to the default calculation only when nothing was
            # supplied; integers have already been validated above.
            if max_mm_tokens is None:
                self._max_mm_tokens[model_cls] = \
                    self._default_max_multimodal_tokens
            else:
                self._max_mm_tokens[model_cls] = max_mm_tokens

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered but
                no maximum number of multi-modal tokens.
            ValueError: If the resulting number of tokens is smaller than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        # A model class without an input mapper is treated as not handled by
        # this plugin.
        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            # Resolve the registered calculation into a concrete number.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens