# base.py
1
import sys
2
from abc import ABC, abstractmethod
3
4
5
6
7
8
9
10
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
                    TypeVar, Union)

import torch
import torch.types
from PIL import Image
from torch import nn
11

12
13
from vllm.config import ModelConfig
from vllm.inputs import InputContext
14
15
16
17
from vllm.logger import init_logger

# Module-level logger for this file, created through vLLM's ``init_logger``.
logger = init_logger(__name__)

BatchedTensors = Union[torch.Tensor, List[torch.Tensor]]
"""
A batch of tensors in one of two layouts.

When every input tensor in the batch has the same size, this is a single
batched tensor; otherwise, it is a list holding one tensor per batch element.
"""

# ``UserDict`` only accepts subscripting (``UserDict[K, V]``) from Python 3.9
# onward, so the typed base is used only where the interpreter supports it.
if sys.version_info >= (3, 9):

    class _MultiModalInputsBase(UserDict[str, torch.Tensor]):
        """Typed ``UserDict`` base class for :class:`MultiModalInputs`."""

else:

    class _MultiModalInputsBase(UserDict):
        """Untyped ``UserDict`` base class for :class:`MultiModalInputs`."""


class MultiModalInputs(_MultiModalInputsBase):
    """
    Keyword arguments destined for :meth:`~torch.nn.Module.forward`,
    stored as a dictionary.
    """

    @staticmethod
    def try_concat(
        tensors: List[torch.Tensor],
        *,
        device: torch.types.Device,
    ) -> BatchedTensors:
        """
        Combine ``tensors`` into one tensor along dim 0 when possible.

        If every tensor has the same shape as ``tensors[0]`` beyond the first
        dimension, they are concatenated with :func:`torch.cat`. Otherwise,
        each tensor has ``squeeze(0)`` applied and the results are returned
        as a list. Either way, the output is moved to ``device``.
        """
        # Avoid initializing CUDA too early
        import torch

        reference_shape = tensors[0].shape[1:]
        shapes_agree = all(t.shape[1:] == reference_shape for t in tensors)

        if not shapes_agree:
            return [t.squeeze(0).to(device=device) for t in tensors]

        return torch.cat(tensors, dim=0).to(device=device)

    @staticmethod
    def batch(
        inputs_list: List["MultiModalInputs"],
        device: torch.types.Device,
    ) -> Dict[str, BatchedTensors]:
        """
        Merge several :class:`MultiModalInputs` into a single dictionary,
        batching the values stored under each key with :meth:`try_concat`.

        Returns an empty dict for an empty list, and raises ``ValueError``
        when the inputs do not all have the same keys.
        """
        if not inputs_list:
            return {}

        keys = inputs_list[0].keys()
        grouped: Dict[str, List[torch.Tensor]] = defaultdict(list)

        for inputs in inputs_list:
            if inputs.keys() != keys:
                raise ValueError(f"Inputs do not share the same keys ({keys})")

            for key, value in inputs.items():
                grouped[key].append(value)

        return {
            key: MultiModalInputs.try_concat(values, device=device)
            for key, values in grouped.items()
        }


class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Typed entries for the built-in modalities of :data:`MultiModalDataDict`.

    Every key is optional (``total=False``).
    """

    # Image data for the ``image`` modality.
    image: Image.Image


MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary containing an item for each modality type to input.

The data belonging to each modality is converted into keyword arguments
to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied.
"""

MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""
A function that takes an :class:`InputContext` and the data of one modality,
and returns a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers.
"""

# Type variable bound to model classes (subclasses of ``nn.Module``), so that
# a class decorator can be typed as returning exactly the class it received.
N = TypeVar("N", bound=Type[nn.Module])

class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Registered model class -> function that converts this plugin's
        # modality data into keyword arguments for that model.
        self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> MultiModalInputs:
        """
        Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.
        If `None` is provided, then the default input mapper is used instead.

        Returns a class decorator. The decorator records the mapper for the
        decorated model class and returns that class unchanged. Registering
        a class a second time logs a warning and replaces the earlier mapper.

        See also:
            :ref:`input_processing_pipeline`
            :ref:`adding_a_new_multimodal_model`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Compare against None explicitly (rather than relying on
            # truthiness) so that the documented "None means default"
            # contract holds for any callable.
            if mapper is None:
                self._input_mappers[model_cls] = self._default_input_mapper
            else:
                self._input_mappers[model_cls] = mapper

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: object) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model, transforming
        the data into a dictionary of model inputs.

        If the data is not something that the mapper expects, throws TypeError.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If this plugin has no input mapper registered for the
                model class resolved from ``model_config``.

        See also:
            :ref:`adding_a_new_multimodal_model`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)