base.py 9.72 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
6
from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast
7
8
9
10
11

import torch
import torch.types
from PIL import Image
from torch import nn
12

13
14
from vllm.config import ModelConfig
from vllm.inputs import InputContext
15
16
17
18
from vllm.logger import init_logger

# Module-level logger; used below to warn when a model class's input mapper
# or maximum-token calculation is registered a second time and overwritten.
logger = init_logger(__name__)

19
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
"""
Either a single tensor or a sequence of tensors. Use the sequence form when
the elements do not all have the same dimensions.

Only one level of nesting (a flat sequence of tensors) is currently supported.
"""

BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors]
"""
The batched counterpart of :class:`NestedTensors`. When every input tensor in
the batch has the same size, this is one batched tensor. Otherwise, it is a
list of :class:`NestedTensors` holding one element for each item in the batch.
"""

if sys.version_info >= (3, 9):
    # On 3.9+ UserDict is subscriptable, so the key/value types are declared
    # directly on the base class.
    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass
else:
    # Before 3.9 UserDict cannot be subscripted, so fall back to the plain,
    # unparameterized class.
    class _MultiModalInputsBase(UserDict):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def try_concat(
        tensors: List[NestedTensors],
        *,
        device: torch.types.Device,
    ) -> BatchedTensors:
        """
        Move a non-empty list of per-item values onto ``device``, batching
        them into a single tensor when their shapes allow it.

        - If the first item is a list, each item is treated as a list. Its
          first element is iterated over, and every element produced is
          moved to ``device``. The result is a list of lists.
        - Otherwise, the items are tensors. If they agree on every dimension
          after the first, they are concatenated along dim 0 into one tensor.
          If they do not agree, dim 0 of each tensor is squeezed and the
          results are returned as a list.
        """
        head = tensors[0]

        # The values may be lists rather than tensors.
        if isinstance(head, list):
            nested_items = cast(List[List[torch.Tensor]], tensors)
            return [[elem.to(device=device) for elem in item[0]]
                    for item in nested_items]

        flat_items = cast(List[torch.Tensor], tensors)
        expected_tail = flat_items[0].shape[1:]
        shapes_match = all(item.shape[1:] == expected_tail
                           for item in flat_items)

        if not shapes_match:
            # Shapes differ, so they cannot be concatenated; keep them as a list.
            return [item.squeeze(0).to(device=device) for item in flat_items]

        return torch.cat(flat_items, dim=0).to(device=device)

    @staticmethod
    def batch(
        inputs_list: List["MultiModalInputs"],
        device: torch.types.Device,
    ) -> Dict[str, BatchedTensors]:
        """
        Batch multiple inputs together into a dictionary.

        Every element of ``inputs_list`` must have the same keys. The values
        collected under each key are combined with :meth:`try_concat`. An
        empty ``inputs_list`` yields an empty dictionary.

        Raises:
            ValueError: If the inputs do not all share the same keys.
        """
        if not inputs_list:
            return {}

        reference_keys = inputs_list[0].keys()
        grouped: Dict[str, List[NestedTensors]] = defaultdict(list)

        for entry in inputs_list:
            if entry.keys() != reference_keys:
                raise ValueError(
                    f"Inputs do not share the same keys ({reference_keys})")

            for key, value in entry.items():
                grouped[key].append(value)

        batched: Dict[str, BatchedTensors] = {}
        for key, values in grouped.items():
            batched[key] = MultiModalInputs.try_concat(values, device=device)
        return batched
96
97


98
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    The modality types that vLLM defines out of the box.

    Every key is optional (``total=False``).
    """

    image: Image.Image
    """An input image, given as a PIL image."""
103
104


105
106
107
MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary that maps each input modality type to the data supplied for it.

Note:
    Keys are not limited to those declared in :class:`MultiModalDataBuiltins`.
    Other modality keys are accepted as well, provided a matching plugin has
    been registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""
Convert raw input data into a dictionary that is passed as keyword arguments
to :meth:`~torch.nn.Module.forward`, much like the tokenizers and processors
in HuggingFace Transformers.

Should throw :exc:`TypeError` when given data it does not support.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
The maximum number of multimodal tokens fed into the language model. It is
either a fixed integer or a function of the :class:`InputContext`. Tokens
that correspond to the input text are not counted.
"""

# Type variable for model classes, so that the registration decorators are
# typed as returning exactly the class they were applied to.
N = TypeVar("N", bound=Type[nn.Module])

133

134
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registrations, populated by the decorators returned
        # from register_input_mapper / register_max_multimodal_tokens.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged. Registering a class that
            already has a mapper logs a warning and overwrites the old one.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Test for None explicitly, matching the documented contract,
            # instead of relying on truthiness: a mapper object that happens
            # to be falsy must not be silently replaced by the default.
            self._input_mappers[model_cls] = (self._default_input_mapper
                                              if mapper is None else mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: object) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered in this plugin for the
                model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of multimodal tokens input to the language
        model. This does not include tokens that correspond to the input text.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
        """Raise :exc:`ValueError` unless ``max_mm_tokens`` is at least 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of multi-modal tokens input to the
        language model for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or calculation) for the
            decorated model class and returns that class unchanged. When it is
            applied, an integer value is validated and raises
            :exc:`ValueError` if it is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Explicit None check, matching the documented contract; integers
            # below 1 have already been rejected by the validation above.
            self._max_mm_tokens[model_cls] = (
                self._default_max_multimodal_tokens
                if max_mm_tokens is None else max_mm_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If no input mapper is registered in this plugin for the model class,
        `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered but
                no maximum number of multi-modal tokens.
            ValueError: If the resulting number of tokens is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        # A registered calculation is evaluated lazily against this model's
        # context; a registered integer is used as-is.
        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens