base.py 4.11 KB
Newer Older
1
2
3
4
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                    TypeVar)

5
6
from vllm.config import ModelConfig
from vllm.inputs import InputContext
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from vllm.logger import init_logger

if TYPE_CHECKING:
    import torch
    from torch import nn

# Module-level logger; MultiModalPlugin.register_input_mapper uses it to warn
# when an input mapper already registered for a model class is overwritten.
logger = init_logger(__name__)


class MultiModalData:
    """
    Base class that contains multi-modal data.

    To add a new modality, add a new file under ``multimodal`` directory.

    In this new file, subclass :class:`~MultiModalData` and
    :class:`~MultiModalPlugin`.

    Finally, register the new plugin to
    :const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
    This enables models to call :meth:`MultiModalRegistry.map_input` for
    the new modality.
    """


# Type of the multi-modal data handled by a plugin (a MultiModalData subclass).
D = TypeVar("D", bound=MultiModalData)
# Type of a model class; MultiModalPlugin.register_input_mapper's decorator
# returns the class it receives, so this keeps the decorated type intact.
N = TypeVar("N", bound=Type["nn.Module"])

MultiModalInputMapper = Callable[[InputContext, D], Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""


class MultiModalPlugin(ABC, Generic[D]):
    """
    Base class that defines data processing logic for a specific modality.

    In particular, we adopt a registry pattern to dispatch data processing
    according to the model being used (considering that different models may
    process the same data differently). This registry is in turn used by
    :class:`~MultiModalRegistry` which acts at a higher level
    (i.e., the modality of the data).
    """

    def __init__(self) -> None:
        # Maps each registered model class to the function that converts data
        # of this plugin's modality into keyword arguments for that model.
        self._input_mappers: Dict[Type["nn.Module"],
                                  MultiModalInputMapper[D]] = {}

    @abstractmethod
    def get_data_type(self) -> Type[D]:
        """
        Get the modality (subclass of :class:`~MultiModalData`) served by
        this plugin.
        """
        raise NotImplementedError

    @abstractmethod
    def _default_input_mapper(self, ctx: InputContext,
                              data: D) -> Dict[str, "torch.Tensor"]:
        """Return a dictionary to be passed as keyword arguments to
        :meth:`~torch.nn.Module.forward`. This is similar in concept to
        tokenizers and processors in HuggingFace Transformers.

        Used by :meth:`register_input_mapper` when no mapper is provided.
        """
        raise NotImplementedError

    def register_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[D]] = None,
    ) -> Callable[[N], N]:
        """
        Register an input mapper to a model class.

        When the model receives input data that matches the modality served by
        this plugin (see :meth:`get_data_type`), the provided function is
        invoked to transform the data into a dictionary of model inputs.
        If `None` is provided, then the default input mapper is used instead.

        Returns a class decorator that records the mapper for the decorated
        model class and returns that class unchanged. Registering again for
        the same model class replaces the previous mapper and logs a warning.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_mappers:
                logger.warning(
                    "Model class %s already has an input mapper "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            # Compare against None explicitly (as the docstring promises)
            # instead of relying on the truthiness of ``mapper``.
            self._input_mappers[model_cls] = (
                mapper if mapper is not None else self._default_input_mapper)

            return model_cls

        return wrapper

    def map_input(self, model_config: ModelConfig,
                  data: D) -> Dict[str, "torch.Tensor"]:
        """
        Apply an input mapper to a :class:`~MultiModalData` instance passed
        to the model, transforming the data into a dictionary of model inputs.

        The model is identified by ``model_config``.

        Raises:
            KeyError: If no input mapper has been registered for the model
                class resolved from ``model_config``.

        TODO: Add guide [ref: PR #5276]
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        mapper = self._input_mappers.get(model_cls)
        if mapper is None:
            raise KeyError(f"No input mapper in {self} is registered for "
                           f"model class {model_cls.__name__}.")

        return mapper(InputContext(model_config), data)