from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Type,
                    TypedDict, TypeVar, Union)

from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger

if TYPE_CHECKING:
    # Needed for static type checking only; the annotations in this module
    # refer to these names as strings, so nothing is imported at runtime.
    import torch
    from PIL import Image
    from torch import nn

logger = init_logger(__name__)

# Type variable bound to model classes, so that a decorator annotated with it
# is typed as returning the same class it was given.
N = TypeVar("N", bound=Type["nn.Module"])


class MultiModalDataBuiltins(TypedDict, total=False):
    """
    Typed description of the modality keys declared in this module.

    The dictionary is declared with ``total=False``, so every key is
    optional and an empty dictionary is a valid instance.
    """

    # Image input, annotated as a PIL image (``PIL.Image.Image``).
    image: "Image.Image"


MultiModalDataDict = Union[MultiModalDataBuiltins, Dict[str, Any]]
"""
A dictionary containing an item for each modality type to input.

The data belonging to each modality is converted into keyword arguments
to the model by the corresponding mapper. By default, the mapper of
the corresponding plugin with the same modality key is applied.
"""

MultiModalInputMapper = Callable[[InputContext, object],
                                 Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""


class MultiModalPlugin(ABC):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Maps each registered model class to the input mapper used for it.
        self._input_mappers: Dict[Type["nn.Module"],
                                  MultiModalInputMapper] = {}

    @abstractmethod
    def get_data_key(self) -> str:
        """
        Get the data key corresponding to the modality.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: object) -> Dict[str, "torch.Tensor"]:
        """Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_key`), the provided function is
        invoked to transform the data into a dictionary of model inputs.
        If `None` is provided, then the default input mapper is used instead.

        Args:
            mapper: The input mapper to register, or ``None`` to register
                :meth:`_default_input_mapper`.

        Returns:
            A class decorator that records the mapper for the decorated
            model class and returns that class unchanged.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Only an explicit ``None`` selects the default mapper, as
            # documented; do not rely on the truthiness of the callable.
            if mapper is None:
                self._input_mappers[model_cls] = self._default_input_mapper
            else:
                self._input_mappers[model_cls] = mapper

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: object) -> Dict[str, "torch.Tensor"]:
        """
        Apply an input mapper to a data passed
        to the model, transforming the data into a dictionary of model inputs.

        The model is identified by ``model_config``.

        Args:
            model_config: Configuration used to resolve the model class whose
                registered mapper is applied.
            data: The data to be transformed by the mapper.

        Returns:
            The dictionary produced by the mapper registered for the model
            class.

        Raises:
            KeyError: If no input mapper is registered in this plugin for the
                resolved model class.
            TypeError: If the data is not something that the mapper expects
                (raised by the mapper itself).

        TODO: Add guide [ref: PR #5276]
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)