import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
                    TypeVar, Union, cast)

import torch
import torch.types
from PIL import Image
from torch import nn

from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger

# Module-level logger; the plugin registration methods below use it to warn
# when an existing registration for a model class is overwritten.
logger = init_logger(__name__)

NestedTensors = Union[List[torch.Tensor], torch.Tensor]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""

BatchedTensors = Union[List[NestedTensors], NestedTensors]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
per item in the batch.
"""

if sys.version_info < (3, 9):
    # On Python < 3.9, UserDict cannot be subscripted, so the base class is
    # left unparameterized there.
    class _MultiModalInputsBase(UserDict):
        pass
else:

    # Python >= 3.9: parameterize the mapping as str -> NestedTensors.
    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def try_concat(
        tensors: List[NestedTensors],
        *,
        device: torch.types.Device,
    ) -> BatchedTensors:
        """
        Combine the values collected under one key into a batch.

        - If the first item is a list, each item's first element
          (``item[0]``) is iterated and its elements are moved to ``device``,
          giving one inner list per item.
        - Otherwise the items are tensors. If they all share the same shape
          past dimension 0, they are concatenated along dimension 0 into a
          single tensor.
        - If their trailing shapes differ, a list is returned instead, with
          ``squeeze(0)`` applied to each tensor.

        Every returned tensor is moved to ``device``.

        Args:
            tensors: The per-item values; must be non-empty, since
                ``tensors[0]`` is inspected.
            device: The device to move the results to.
        """
        # Items may be lists rather than tensors. Only the first item is
        # checked, so a mix of lists and tensors is not detected here.
        # NOTE(review): presumably each list item wraps its content in a
        # length-1 outer list (hence ``item[0]``) -- confirm against the
        # input mappers.
        if isinstance(tensors[0], list):
            return [[t.to(device=device) for t in item[0]]
                    for item in tensors]

        tensors_ = cast(List[torch.Tensor], tensors)
        unbatched_shape = tensors_[0].shape[1:]

        # Tensors can only be concatenated along dim 0 when they agree on
        # every other dimension.
        if all(t.shape[1:] == unbatched_shape for t in tensors_):
            return torch.cat(tensors_, dim=0).to(device=device)

        # Ragged shapes: keep the tensors separate. ``squeeze(0)`` removes
        # the leading dimension only when its size is 1.
        return [t.squeeze(0).to(device=device) for t in tensors_]

    @staticmethod
    def batch(
        inputs_list: List["MultiModalInputs"],
        device: torch.types.Device,
    ) -> Dict[str, BatchedTensors]:
        """
        Batch multiple inputs together into a dictionary.

        The values stored under each key are gathered across ``inputs_list``
        and combined with :meth:`try_concat`.

        Returns:
            A mapping from each key to its batched value; an empty dict when
            ``inputs_list`` is empty.

        Raises:
            ValueError: If the inputs do not all share the same keys.
        """
        if not inputs_list:
            return {}

        keys = inputs_list[0].keys()

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                msg = f"Inputs do not share the same keys ({keys})"
                raise ValueError(msg)

            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs.try_concat(item_list, device=device)
            for k, item_list in item_lists.items()
        }


class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Modality types that are predefined by vLLM.

    Declared with ``total=False``, so every key is optional.
    """

    image: Image.Image
    """The input image."""


MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, raise :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.

Either a fixed integer, or a callable that computes the value from an
:class:`InputContext`.
"""

# Type variable for model classes: lets the registration decorators below
# return exactly the class they were applied to.
N = TypeVar("N", bound=Type[nn.Module])


class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Input mapper per model class; populated by register_input_mapper().
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        # Max multi-modal token count (int or callable) per model class;
        # populated by register_max_multimodal_tokens().
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, raise :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A decorator that records the mapper for the decorated model class
            and returns that class unchanged.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Explicit None check, matching the documented contract (a
            # merely falsy mapper is still used as given).
            self._input_mappers[model_cls] = (
                self._default_input_mapper if mapper is None else mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: object) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            TypeError: If the data type is not supported.
            KeyError: If no input mapper is registered for the model class.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of multimodal tokens input to the language
        model. This does not include tokens that correspond to the input text.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """Raise :exc:`ValueError` unless ``max_mm_tokens`` is at least 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of multi-modal tokens input to the
        language model for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A decorator that records the value for the decorated model class
            and returns that class unchanged.

        Raises:
            ValueError: When the decorator is applied, if ``max_mm_tokens``
                is an integer smaller than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            # Validate first, so a rejected value is never logged as having
            # overwritten an existing registration.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            self._max_mm_tokens[model_cls] = (
                self._default_max_multimodal_tokens
                if max_mm_tokens is None else max_mm_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.
        Applicability is determined by whether an input mapper is registered
        for the model class.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If an input mapper is registered for the model class
                but no maximum number of tokens is.
            ValueError: If the resulting number of tokens is smaller than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        # The registered value may be a fixed int or a callable computing it.
        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens