base.py 10.6 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
6
from typing import Tuple, Type, TypedDict, TypeVar, Union, cast
7

8
import numpy as np
9
10
11
12
import torch
import torch.types
from PIL import Image
from torch import nn
13
from typing_extensions import TypeAlias
14

15
16
from vllm.config import ModelConfig
from vllm.inputs import InputContext
17
from vllm.logger import init_logger
18
from vllm.utils import JSONTree, json_map_leaves
19
20
21

logger = init_logger(__name__)

NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
"""
Either a single tensor or a sequence of tensors. The sequence form is used
when the per-element dimensions differ, so the elements cannot be combined
into one tensor. At most one level of nesting is currently supported.
"""

BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
"""
Nested JSON-like structure whose leaves are tensors produced by
:meth:`MultiModalInputs.batch`.
"""

BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
"""
Mapping from input names to nested tensor structures produced by
:meth:`MultiModalInputs.batch`.
"""

if sys.version_info >= (3, 9):

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass

else:
    # Before Python 3.9, UserDict does not support subscripting, so the
    # generic parameters cannot be given here.
    class _MultiModalInputsBase(UserDict):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    Dictionary of keyword arguments that are passed on to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors:
        """
        Combine the per-item values of one key across a batch.

        If every tensor matches in all dimensions except the first, they are
        concatenated along dimension 0 into a single tensor. Otherwise a list
        with one entry per batch item is returned, where each tensor has its
        leading dimension dropped if that dimension has size 1.

        If the first item is a list, all items are treated as lists and never
        concatenated. For each item, the elements of its first entry are
        gathered into a new list.
        """
        if isinstance(tensors[0], list):
            nested_items = cast(List[List[torch.Tensor]], tensors)
            return [[elem for elem in item[0]] for item in nested_items]

        flat_items = cast(List[torch.Tensor], tensors)
        trailing_shape = flat_items[0].shape[1:]

        shapes_agree = all(item.shape[1:] == trailing_shape
                           for item in flat_items)
        if not shapes_agree:
            return [item.squeeze(0) for item in flat_items]

        return torch.cat(flat_items, dim=0)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Merge several inputs into one dictionary with the same keys.

        For each key, the values from all inputs are concatenated into a
        single batched tensor when they are tensors that match in every
        dimension except the first. Otherwise the result for that key is a
        list with one entry per input (see :meth:`_try_concat`).

        Raises:
            ValueError: If the inputs do not all have the same keys.
        """
        if not inputs_list:
            return {}

        expected_keys = inputs_list[0].keys()
        grouped: Dict[str, List[NestedTensors]] = defaultdict(list)

        for item in inputs_list:
            if item.keys() != expected_keys:
                raise ValueError(
                    f"Inputs do not share the same keys ({expected_keys})")

            for key, value in item.items():
                grouped[key].append(value)

        return {
            key: MultiModalInputs._try_concat(values)
            for key, values in grouped.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Apply ``.to(device, non_blocking=True)`` to every leaf of
        ``batched_inputs`` via :func:`json_map_leaves` and return the result.
        """

        def _move(leaf: torch.Tensor) -> torch.Tensor:
            return leaf.to(device, non_blocking=True)

        return json_map_leaves(_move, batched_inputs)
117
118


119
class MultiModalDataBuiltins(TypedDict, total=False):
    """Built-in modality entries that vLLM defines itself."""

    image: Image.Image
    """An input image."""

    audio: Tuple[np.ndarray, Union[int, float]]
    """An input audio array, paired with its sampling rate."""

128

129
130
131
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
Mapping from each modality type to the item supplied for it.

Note:
    Modality keys other than those in :class:`MultiModalDataBuiltins` are
    also accepted, provided a custom plugin for them has been registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""
Function that turns raw data into a dictionary of keyword arguments for
:meth:`~torch.nn.Module.forward`, similar in concept to the tokenizers and
processors in HuggingFace Transformers.

It should raise :exc:`TypeError` when given data it does not support.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Maximum number of multimodal tokens input to the language model, given
either as a fixed integer or as a function of the :class:`InputContext`.
Tokens that correspond to the input text are not included.
"""

# Bound to model classes so that registration decorators can return the
# exact class they were applied to.
N = TypeVar("N", bound=Type[nn.Module])

157

158
class MultiModalPlugin(ABC):
    """
    Abstract base for the data-processing logic of a single modality.

    Each plugin keeps its own registry keyed by model class, because
    different models may process the same data differently.
    :class:`~MultiModalRegistry` sits one level above and dispatches on the
    modality of the data.

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Model class -> callable that turns raw data into model inputs.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        # Model class -> fixed token budget, or a callable that computes it.
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """Return the data key that identifies the modality of this plugin."""
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """
        Build the keyword arguments for :meth:`~torch.nn.Module.forward` from
        ``data``, much like a tokenizer or processor in HuggingFace
        Transformers.

        Implementations should raise :exc:`TypeError` for unsupported data.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Return a decorator that registers an input mapper for a model class.

        When a model receives input data of the modality served by this plugin
        (see :meth:`get_data_key`), the registered mapper is called to turn
        that data into a dictionary of model inputs.

        If ``mapper`` is ``None``, the plugin's default input mapper is
        registered instead. Registering again for the same class logs a
        warning and replaces the earlier mapper.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def decorator(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            chosen = mapper if mapper else self._default_input_mapper
            self._input_mappers[model_cls] = chosen
            return model_cls

        return decorator

    def map_input(self, model_config: ModelConfig,
                  data: object) -> MultiModalInputs:
        """
        Convert ``data`` into a dictionary of model inputs using the input
        mapper registered for the model described by ``model_config``.

        Raises:
            TypeError: If the data type is not supported.
            KeyError: If no input mapper is registered for the model class.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Deferred import: importing at module level would be circular.
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is not None:
            return mapper(InputContext(model_config), data)

        raise KeyError(f"No input mapper in {self} is registered for "
                       f"model class {model_cls.__name__}.")

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Compute the maximum number of multimodal tokens input to the language
        model, excluding tokens that correspond to the input text.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """Raise :exc:`ValueError` if ``max_mm_tokens`` is less than 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Return a decorator that registers, for a model class, the maximum
        number of multi-modal tokens input to the language model.

        ``max_mm_tokens`` may be a fixed integer, which is validated when the
        decorator is applied, or a callable, which is evaluated and validated
        later by :meth:`get_max_multimodal_tokens`. If it is ``None``, the
        default calculation is registered instead.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def decorator(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            # Fixed counts are checked eagerly; callables are checked once
            # their result is known.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            calc = (max_mm_tokens
                    if max_mm_tokens else self._default_max_multimodal_tokens)
            self._max_mm_tokens[model_cls] = calc
            return model_cls

        return decorator

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Return the maximum number of multi-modal tokens, used for profiling
        the memory usage of the model described by ``model_config``.

        Returns ``0`` when this plugin has no input mapper registered for the
        model class, i.e. the plugin does not apply to that model.

        Raises:
            KeyError: If an input mapper is registered for the model class
                but no maximum token count is.
            ValueError: If the resolved token count is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Deferred import: importing at module level would be circular.
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        if model_cls not in self._input_mappers:
            return 0

        calc = self._max_mm_tokens.get(model_cls)
        if calc is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        resolved = (calc(InputContext(model_config))
                    if callable(calc) else calc)
        self._validate_max_multimodal_tokens(resolved)

        return resolved