# registry.py
import functools
2
from typing import Dict, Optional, Sequence
3

4
import torch
5
6

from vllm.config import ModelConfig
7
8
from vllm.logger import init_logger

9
from .audio import AudioPlugin
10
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
11
                   MultiModalPlugin, MultiModalTokensCalc)
12
from .image import ImagePlugin
13
14
15
16
17
18

# Module-level logger; register_plugin uses it to warn when a plugin for an
# already-registered data type is replaced.
logger = init_logger(__name__)


class MultiModalRegistry:
    """
    A registry that dispatches data processing to the
    :class:`~vllm.multimodal.MultiModalPlugin` for each modality.

    Plugins are indexed by the value of their ``get_data_key()``; that key
    is the modality name accepted by the ``register_*`` methods and used to
    route each entry of the data passed to :meth:`map_input`.
    """

    # NOTE(review): these instances are created once, when the class body
    # runs, and are shared by every registry built with the default
    # ``plugins`` argument. If a plugin keeps its registered mappers / token
    # calculators on itself (presumably so -- confirm in ``.base``), those
    # registrations are visible through all such registries.
    DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin())

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Args:
            plugins: Plugins to make available, indexed by their data key.
                If several plugins share a data key, the last one wins;
                unlike :meth:`register_plugin`, no warning is logged.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a multi-modal plugin so it can be recognized by vLLM.

        If a plugin is already registered under the same data key, it is
        replaced and a warning is logged.

        See also:
            :ref:`adding_multimodal_plugin`
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin

    def _get_plugin(self, data_type_key: str) -> MultiModalPlugin:
        """
        Look up the plugin registered for ``data_type_key``.

        Raises:
            NotImplementedError: If no plugin is registered for that key.
        """
        plugin = self._plugins.get(data_type_key)
        if plugin is not None:
            return plugin

        msg = f"Unknown multi-modal data type: {data_type_key}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type_key: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        Returns whatever the plugin's ``register_input_mapper`` returns.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        Raises:
            NotImplementedError: If no plugin handles ``data_type_key``.
        """
        return self._get_plugin(data_type_key).register_input_mapper(mapper)

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        Shorthand for ``register_input_mapper("image", mapper)``.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalDataDict) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model.

        The data belonging to each modality is passed to the corresponding
        plugin which in turn converts the data into keyword arguments
        via the input mapper registered for that model.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            NotImplementedError: If ``data`` contains a modality with no
                registered plugin.
            ValueError: If the mappers of two modalities produce the same
                keyword argument name.
        """
        merged_dict: Dict[str, torch.Tensor] = {}

        for data_key, data_value in data.items():
            plugin = self._get_plugin(data_key)
            input_dict = plugin.map_input(model_config, data_value)

            for input_key, input_tensor in input_dict.items():
                # All modalities are merged into one set of keyword
                # arguments; without this check a clash would silently
                # overwrite one modality's input with another's.
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalInputs(merged_dict)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        The returned callable takes the multi-modal data and is equivalent
        to ``self.map_input(model_config, data)``.
        """
        return functools.partial(self.map_input, model_config)

    def register_max_multimodal_tokens(
        self,
        data_type_key: str,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of tokens, belonging to a
        specific modality, input to the language model for a model class.

        Returns whatever the plugin's ``register_max_multimodal_tokens``
        returns.

        Raises:
            NotImplementedError: If no plugin handles ``data_type_key``.
        """
        plugin = self._get_plugin(data_type_key)
        return plugin.register_max_multimodal_tokens(max_mm_tokens)

    def register_max_image_tokens(
        self,
        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
    ):
        """
        Register the maximum number of image tokens
        input to the language model for a model class.

        Shorthand for ``register_max_multimodal_tokens("image", ...)``.
        """
        return self.register_max_multimodal_tokens("image", max_mm_tokens)

    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
        """
        Get the maximum number of multi-modal tokens
        for profiling the memory usage of a model.

        The result is the sum over every registered plugin.

        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
        """
        return sum(
            plugin.get_max_multimodal_tokens(model_config)
            for plugin in self._plugins.values())