base.py 10.9 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)
6

7
import numpy as np
8
9
10
11
import torch
import torch.types
from PIL import Image
from torch import nn
12
from typing_extensions import TypeAlias
13

14
15
from vllm.config import ModelConfig
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
from vllm.utils import JSONTree, is_list_of, json_map_leaves
18
19
20

logger = init_logger(__name__)

NestedTensors: TypeAlias = Union[List["NestedTensors"], List[torch.Tensor],
                                 torch.Tensor]
"""
A tensor, or an arbitrarily nested list of tensors.

A list is used instead of a single tensor when the dimensions of the
elements do not match, so that they cannot be stacked together.
"""

BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""

# Generic subscription of ``UserDict`` (``UserDict[K, V]``) is only available
# from Python 3.9 onwards, so older interpreters use the unparameterized base.
if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        A tensor is returned unchanged. A list is processed element-wise
        first; if every resulting element is a tensor and all of them share
        one shape, they are combined with :func:`torch.stack`. Otherwise the
        (recursively processed) list is returned as a list.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]

        if not stacked:
            # ``torch.stack`` rejects an empty sequence, and an empty list
            # passes the "all elements are tensors" check vacuously, so
            # return it unchanged instead of crashing.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        first_shape = tensors_[0].shape
        if any(t.shape != first_shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.

        An empty ``inputs_list`` yields an empty dictionary.

        Raises:
            ValueError: If the inputs do not all have the same set of keys.
        """
        if not inputs_list:
            return {}

        keys = inputs_list[0].keys()

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                # Report only the key names: the repr of a keys view embeds
                # the whole mapping, i.e. every tensor of the first input.
                msg = f"Inputs do not share the same keys ({list(keys)})"
                raise ValueError(msg)

            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move the batched tensors to ``device``.

        Every leaf reached by ``json_map_leaves`` is transferred with
        ``.to(device, non_blocking=True)``; all leaves are assumed to be
        tensors.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Modality types that are predefined by vLLM.

    Every key is optional (``total=False``), so a prompt may supply any
    subset of these modalities.
    """

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
MultiModalDataDict: TypeAlias = Union[MultiModalDataBuiltins,
                                      Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""

MultiModalInputMapper: TypeAlias = Callable[
    [InputContext, MultiModalData[object]], MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, raise :exc:`TypeError`.
"""

MultiModalTokensCalc: TypeAlias = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.

Either a fixed integer, or a callable that computes the value from the
:class:`InputContext`.
"""

# Type variable for model classes passed through the registration
# decorators, so the decorated class keeps its precise type.
N = TypeVar("N", bound=Type[nn.Module])
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registrations, populated by the decorators returned
        # from ``register_input_mapper`` / ``register_max_multimodal_tokens``.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns the class itself. Registering again for the
            same class logs a warning and replaces the previous mapper.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_mappers[model_cls] = mapper \
                or self._default_input_mapper

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            TypeError: If the data type is not supported.
            KeyError: If no input mapper is registered for the model class.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """
        Check that a maximum token count is at least 1.

        Raises:
            ValueError: If ``max_mm_tokens`` is less than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ) -> Callable[[N], N]:
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or calculator) for the
            decorated model class and returns the class itself. An integer
            value is validated when the decorator is applied.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            # Fixed integers can be checked eagerly; callables are only
            # validated once evaluated in ``get_max_multimodal_tokens``.
            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = max_mm_tokens \
                or self._default_max_multimodal_tokens

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered but
                no maximum token count.
            ValueError: If the resulting token count is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        # Applicability is decided by whether an input mapper is registered.
        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens