"vscode:/vscode.git/clone" did not exist on "451af742caf03d8e11c0256658cbb254486486fd"
base.py 10.9 KB
Newer Older
1
import sys
2
from abc import ABC, abstractmethod
3
from collections import UserDict, defaultdict
4
5
from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type,
                    TypedDict, TypeVar, Union, cast, final)
6

7
import numpy as np
8
9
10
11
import torch
import torch.types
from PIL import Image
from torch import nn
12
from typing_extensions import TypeAlias
13

14
15
from vllm.config import ModelConfig
from vllm.inputs import InputContext
16
from vllm.logger import init_logger
17
from vllm.utils import JSONTree, is_list_of, json_map_leaves
18
19
20

# Module-level logger obtained from vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)

21
# Recursive alias: a tensor, a flat list of tensors, or a list of further
# nested structures of the same kind.
NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor]
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
"""

26
# Maps each keyword-argument name to its (possibly nested) batched tensors.
BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""

# Pick the base class for ``MultiModalInputs`` depending on the interpreter:
# the parametrized form ``UserDict[str, NestedTensors]`` is only used on
# Python 3.9 and later.
if sys.version_info < (3, 9):
    # UserDict cannot be subscripted
    class _MultiModalInputsBase(UserDict):
        pass
else:

    class _MultiModalInputsBase(UserDict[str, NestedTensors]):
        pass


class MultiModalInputs(_MultiModalInputsBase):
    """
    A dictionary that represents the keyword arguments to
    :meth:`~torch.nn.Module.forward`.
    """

    @staticmethod
    def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
        """
        Recursively stacks lists of tensors when they all have the same shape.

        Args:
            nested_tensors: A tensor, or an arbitrarily nested list of tensors.

        Returns:
            The input itself if it is already a tensor; a single stacked
            tensor if every element (after recursive stacking) is a tensor of
            the same shape; otherwise a list of the recursively processed
            elements. An empty list is returned as an empty list.
        """
        if isinstance(nested_tensors, torch.Tensor):
            return nested_tensors

        stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]

        if not stacked:
            # ``torch.stack`` rejects an empty sequence, so there is nothing
            # to stack; hand back the empty list unchanged.
            return stacked

        if not is_list_of(stacked, torch.Tensor, check="all"):
            # Only tensors (not lists) can be stacked.
            return stacked

        tensors_ = cast(List[torch.Tensor], stacked)
        first_shape = tensors_[0].shape
        if any(t.shape != first_shape for t in tensors_):
            # The tensors have incompatible shapes and can't be stacked.
            return tensors_

        return torch.stack(tensors_)

    @staticmethod
    def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
        """
        Batch multiple inputs together into a dictionary.

        The resulting dictionary has the same keys as the inputs.
        If the corresponding value from each input is a tensor and they all
        share the same shape, the output value is a single batched tensor;
        otherwise, the output value is a list containing the original value
        from each input.

        Args:
            inputs_list: The per-item inputs to combine.

        Returns:
            A dictionary mapping every key seen in any input to the batched
            values for that key; empty if ``inputs_list`` is empty.
        """
        if not inputs_list:
            return {}

        item_lists: Dict[str, List[NestedTensors]] = defaultdict(list)

        for inputs in inputs_list:
            # For models that supports multiple modalities (e.g. Qwen2-VL),
            # different modalities will return different data keys,
            # so batch() should skip the same key check.
            for k, v in inputs.items():
                item_lists[k].append(v)

        return {
            k: MultiModalInputs._try_stack(item_list)
            for k, item_list in item_lists.items()
        }

    @staticmethod
    def as_kwargs(
        batched_inputs: BatchedTensorInputs,
        *,
        device: torch.types.Device,
    ) -> BatchedTensorInputs:
        """
        Move every tensor leaf of the batched inputs to ``device``.

        Args:
            batched_inputs: The batched inputs whose leaves are tensors.
            device: The target device passed to :meth:`torch.Tensor.to`.

        Returns:
            A structure of the same shape with each leaf replaced by the
            result of ``leaf.to(device, non_blocking=True)``.
        """
        json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)

        json_mapped = json_map_leaves(
            lambda x: x.to(device, non_blocking=True),
            json_inputs,
        )

        return cast(BatchedTensorInputs, json_mapped)
111
112


113
114
115
116
117
118
119
120
121
122
123
124
# Type variable standing for the type of one data instance of a modality.
_T = TypeVar("_T")

# Generic alias: parametrize with the per-instance type, e.g.
# ``MultiModalData[Image.Image]``.
MultiModalData: TypeAlias = Union[_T, List[_T]]
"""
Either a single data instance, or a list of data instances.

The number of data instances allowed per modality is restricted by
`--limit-mm-per-prompt`.
"""


@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Modality types that are predefined by vLLM."""

    # ``total=False``: every key below is optional.

    image: MultiModalData[Image.Image]
    """The input image(s)."""

    audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]]
    """The input audio item(s) and corresponding sampling rate(s)."""
133

134

135
136
# Either the built-in modality keys, or any string-keyed mapping of
# modality data.
MultiModalDataDict = Union[MultiModalDataBuiltins,
                           Mapping[str, MultiModalData[object]]]
"""
A dictionary containing an item for each modality type to input.

Note:
    This dictionary also accepts modality keys defined outside
    :class:`MultiModalDataBuiltins` as long as a customized plugin is registered
    through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
    Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
146

147
148
# Signature of an input mapper: (input context, modality data) -> model inputs.
MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]],
                                 MultiModalInputs]
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.

If the data is not supported, throw :exc:`TypeError`.
"""

# Either a fixed token count (``int``) or a callable that computes the count
# from an ``InputContext``.
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
162

163
164
# Type variable bound to ``nn.Module`` subclasses (the class objects, not
# instances); the registration decorators in this module are annotated as
# taking and returning the same ``N``.
N = TypeVar("N", bound=Type[nn.Module])

165

166
class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).

    See also:
        :ref:`adding_multimodal_plugin`
    """

    def __init__(self) -> None:
        # Per-model-class registries, filled by the ``register_*`` decorators.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        If the data is not supported, throw :exc:`TypeError`.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.

        If `None` is provided, then the default input mapper is used instead.

        Returns:
            A class decorator that records the mapper for the decorated model
            class and returns that class unchanged. A warning is logged if
            the class already had a mapper registered in this plugin.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Fall back to the plugin default only when no mapper was given.
            if mapper is None:
                self._input_mappers[model_cls] = self._default_input_mapper
            else:
                self._input_mappers[model_cls] = mapper

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalData[object]) -> MultiModalInputs:
        """
        Transform the data into a dictionary of model inputs using the
        input mapper registered for that model.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper is registered in this plugin for
                the model class.
            TypeError: If the data type is not supported.

        See also:
            - :ref:`input_processing_pipeline`
            - :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)

    @abstractmethod
    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        """
        Calculate the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model.
        """
        raise NotImplementedError

    def _validate_max_multimodal_tokens(self, max_mm_tokens: int) -> None:
        """
        Check that a token count is a positive integer.

        Raises:
            ValueError: If ``max_mm_tokens`` is less than 1.
        """
        if max_mm_tokens < 1:
            raise ValueError("You should set the number of tokens to a "
                             f"positive integer. Found: {max_mm_tokens}")

    def register_max_multimodal_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ) -> Callable[[N], N]:
        """
        Register the maximum number of tokens, corresponding to a single
        instance of multimodal data, that are passed to the language model
        for a model class.

        If `None` is provided, then the default calculation is used instead.

        Returns:
            A class decorator that records the value (or calculation) for the
            decorated model class and returns that class unchanged. If an
            ``int`` was given, it is validated when the decorator is applied
            and :exc:`ValueError` is raised if it is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._max_mm_tokens:
                logger.warning(
                    "Model class %s already calculates maximum number of "
                    "tokens in %s. It is overwritten by the new one.",
                    model_cls, self)

            if isinstance(max_mm_tokens, int):
                # Fixed counts can be checked eagerly; callables are only
                # checked once evaluated in get_max_multimodal_tokens().
                self._validate_max_multimodal_tokens(max_mm_tokens)

            # Fall back to the plugin default only when nothing was given.
            if max_mm_tokens is None:
                self._max_mm_tokens[model_cls] = \
                    self._default_max_multimodal_tokens
            else:
                self._max_mm_tokens[model_cls] = max_mm_tokens

            return model_cls

        return wrapper

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        If this registry is not applicable to the model, `0` is returned.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If the model class has an input mapper registered in
                this plugin but no maximum number of tokens.
            ValueError: If the resulting number of tokens is less than 1.

        See also:
            :ref:`enabling_multimodal_inputs`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        # A model without an input mapper for this modality does not use it.
        if model_cls not in self._input_mappers:
            return 0

        max_mm_tokens = self._max_mm_tokens.get(model_cls)
        if max_mm_tokens is None:
            raise KeyError(f"No maximum number of multi-modal tokens is given "
                           f"for model class {model_cls.__name__} in {self}.")

        if callable(max_mm_tokens):
            max_mm_tokens = max_mm_tokens(InputContext(model_config))

        self._validate_max_multimodal_tokens(max_mm_tokens)

        return max_mm_tokens