# base.py (11.9 KB)
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)
6

7
import numpy as np
8
9
10
11
import torch
import torch.types
from PIL import Image
from torch import nn
12
from typing_extensions import TypeAlias
13

14
15
from vllm.config import ModelConfig
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
18
from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of,
                        json_map_leaves)
19
20
21

# Module-level logger for this file, created via vLLM's ``init_logger``.
logger = init_logger(__name__)

22
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
23
"""
24
Uses a list instead of a tensor if the dimensions of each element do not match.
25
26
"""

27
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
28
29
30
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
31
32
33
34
35
36
37
38
"""

# ``collections.UserDict`` can only be subscripted (``UserDict[K, V]``) from
# Python 3.9 onwards, so older interpreters get the unparameterized base.
if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        NumPy arrays are converted with :func:`torch.from_numpy`, and Python
        ``int``/``float`` values with :func:`torch.tensor`. A list is returned
        as a list (with its elements converted recursively) when:

        - it is empty, since :func:`torch.stack` cannot stack zero tensors;
        - not all of its elements are tensors after conversion; or
        - its tensors do not all share the same shape.

        Otherwise the elements are stacked into a single tensor.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        if isinstance(nested_tensors, np.ndarray):
            return torch.from_numpy(nested_tensors)

        if isinstance(nested_tensors, (int, float)):
            return torch.tensor(nested_tensors)

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]

        if not stacked:
            # ``is_list_of`` with check="all" is vacuously true for an empty
            # list, and ``torch.stack([])`` raises; keep the empty list.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        if any(t.shape != tensors_[0].shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary contains every key that appears in any of
        the inputs. For each key, the values from the inputs that have it are
        collected in order and passed through :meth:`_try_stack`. If they are
        all tensors sharing the same shape, the output value is a single
        batched tensor; otherwise, it is a list containing each input's
        (converted) value.

        An empty ``inputs_list`` yields an empty dictionary.
        """
        if len(inputs_list) == 0:
            return {}

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            # For models that supports multiple modalities (e.g. Qwen2-VL),
            # different modalities will return different data keys,
            # so batch() should skip the same key check.

            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move every tensor leaf of ``batched_inputs`` to ``device``.

        The copies are issued with ``non_blocking=True``. The nesting
        structure of the input is preserved, so the result can be passed
        as keyword arguments to the model.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
118
119


120
121
122
123
124
125
126
127
128
129
130
131
# Generic element type parameter for ``MultiModalData``.
_T = TypeVar("_T")

MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Modality types that are predefined by vLLM.

    Every key is optional (``total=False``): a prompt supplies only the
    modalities it uses.
    """

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
140

141

142
143
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
153

154
155
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
A function that returns a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

Besides the positional ``(ctx, data)`` arguments, a custom mapper may declare
keyword-only parameters; matching entries of
``ModelConfig.mm_processor_kwargs`` are forwarded to it by
:meth:`MultiModalPlugin.map_input`.

If the data is not supported, raise :exc:`TypeError`.
"""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Either a fixed maximum number of multimodal tokens input to the language
model, or a function that computes it from an :class:`InputContext`.
Tokens corresponding to the input text are not included.
"""

# Type variable for the model classes decorated by the ``register_*``
# methods of :class:`MultiModalPlugin`; the decorated class is returned as-is.
N = TypeVar("N", bound=Type[nn.Module])

172

173
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Input mapper registered per model class via register_input_mapper.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        # Token budget (an int or a callable) registered per model class via
        # register_max_multimodal_tokens.
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.
        Registering again for the same model class replaces the previous
        mapper and logs a warning.

        Returns:
            A class decorator that records the mapper and returns the model
            class unchanged.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_mappers[model_cls] = (mapper
                                              or self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered for the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        # Only get processor kwargs at mapping time if we are not using the
        # default input mapper; no overrides are used on the default here
        # because they should be passed to the huggingface resource at
        # initialization time.
        if mapper != self._default_input_mapper:
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                mapper, overrides=model_config.mm_processor_kwargs)
        else:
            mm_processor_kwargs = {}

        return mapper(InputContext(model_config), data, **mm_processor_kwargs)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """Raise :exc:`ValueError` unless ``max_mm_tokens`` is at least 1."""
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.
        A fixed integer value is validated when the decorator is applied;
        a callable is validated on its result in
        :meth:`get_max_multimodal_tokens`.

        Returns:
            A class decorator that records the value and returns the model
            class unchanged.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            if isinstance(max_mm_tokens, int):
                self._validate_max_multimodal_tokens(max_mm_tokens)

            self._max_mm_tokens[model_cls] = (
                max_mm_tokens or self._default_max_multimodal_tokens)

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model (i.e. no input mapper
        is registered for its model class), `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If an input mapper is registered for the model class
                but no maximum number of tokens is.
            ValueError: If the resulting value is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            # Forward only the mm_processor_kwargs overrides that the callable
            # accepts as keyword-only arguments.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                max_mm_tokens, overrides=model_config.mm_processor_kwargs)
            max_mm_tokens = max_mm_tokens(InputContext(model_config),
                                          **mm_processor_kwargs)

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens