base.py 10.7 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)
6

7
import numpy as np
8
9
10
11
import torch
import torch.types
from PIL import Image
from torch import nn
12
from typing_extensions import TypeAlias
13

14
15
from vllm.config import ModelConfig
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
from vllm.utils import json_map_leaves
18
19
20

# Module-level logger; the plugin registration decorators in this module use
# it to warn when an existing registration for a model class is overwritten.
logger = init_logger(__name__)

21
NestedTensors = Union[List["NestedTensors"], torch.Tensor]
22
"""
23
Uses a list instead of a tensor if the dimensions of each element do not match.
24
25
"""

26
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
27
28
29
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
30
31
32
33
34
35
36
37
"""

# ``collections.UserDict`` only supports subscription (generic aliases) from
# Python 3.9 onwards, so older interpreters fall back to the plain base class.
if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        A tensor is returned unchanged. A list is processed element by
        element; the processed list is returned as-is when it is empty, when
        any processed element is still a list, or when the element shapes
        differ. Otherwise the elements are combined with
        :func:`torch.stack`.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
        if not stacked:
            # torch.stack() rejects an empty sequence; keep the empty list.
            return stacked

        if any(isinstance(t, list) for t in stacked):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.

        An empty ``inputs_list`` yields an empty dictionary.

        Raises:
            ValueError: If the inputs do not all have the same set of keys
                as the first input.
        """
        if not inputs_list:
            return {}

        # The first input defines the key set every other input must match.
        keys = inputs_list[0].keys()

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                msg = f"Inputs do not share the same keys ({keys})"
                raise ValueError(msg)

            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Apply ``x.to(device, non_blocking=True)`` to the leaves of
        ``batched_inputs`` via :func:`json_map_leaves` and return the result.
        """
        return json_map_leaves(lambda x: x.to(device, non_blocking=True),
                               batched_inputs)
106
107


108
109
110
111
112
113
114
115
116
117
118
119
# Placeholder for the per-item data type of a modality; it is filled in where
# the alias is subscripted (e.g. ``MultiModalData[Image.Image]`` for images).
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    # ``total=False`` makes every key optional, so a prompt may carry any
    # subset of the modalities declared here.
    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
128

129

130
131
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""

# Type variable bound to model classes (not instances); it lets the
# registration decorators return exactly the class type they were given.
N = TypeVar("N", bound=Type[nn.Module])

160

161
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registries, filled in by the ``register_*``
        # decorators below and consulted by ``map_input`` /
        # ``get_max_multimodal_tokens``.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Only an explicit ``None`` selects the default mapper.
            self._input_mappers[model_cls] = (self._default_input_mapper
                                              if mapper is None else mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered in this plugin for the
                model class resolved from ``model_config``.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """
        Check that ``max_mm_tokens`` is at least 1.

        Raises:
            ValueError: If ``max_mm_tokens`` is less than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ) -> Callable[[N], N]:
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or calculation) for the
            decorated model class and returns that class unchanged. Applying
            the decorator raises :exc:`ValueError` when an integer smaller
            than 1 was supplied.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            # A fixed integer can be checked at registration time; a callable
            # is only checked once it is evaluated in
            # ``get_max_multimodal_tokens``.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Only an explicit ``None`` selects the default calculation.
            self._max_mm_tokens[model_cls] = (
                self._default_max_multimodal_tokens
                if max_mm_tokens is None else max_mm_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered in
                this plugin but no maximum number of tokens.
            ValueError: If the resolved number of tokens is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        # No input mapper for this model class means this plugin does not
        # apply to it.
        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError("No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens