registry.py 4.06 KB
Newer Older
1
import functools
2
from typing import Dict, Optional, Sequence
3

4
import torch
5
6

from vllm.config import ModelConfig
7
8
from vllm.logger import init_logger

9
10
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
                   MultiModalPlugin)
11
from .image import ImagePlugin
12
13
14
15
16
17

# Module-level logger; used below to warn when registering a plugin replaces
# one that was already registered for the same data type.
logger = init_logger(__name__)


class MultiModalRegistry:
    """
    A registry to dispatch data processing
    according to its modality and the target model.

    The registry handles both external and internal data input.

    Plugins are stored under the key each one reports from
    ``get_data_key()``; at most one plugin is kept per key.
    """

    # NOTE: these instances are created once, when the class is defined, so
    # every registry built with the default ``plugins`` argument holds the
    # same plugin objects.
    DEFAULT_PLUGINS = (ImagePlugin(), )

    def __init__(
            self,
            *,
            plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None:
        """
        Create a registry pre-populated with the given plugins.

        Args:
            plugins: Plugins to register initially, indexed by the value of
                their ``get_data_key()``. If several plugins report the same
                key, the last one in the sequence is the one kept.
        """
        self._plugins = {p.get_data_key(): p for p in plugins}

    def register_plugin(self, plugin: MultiModalPlugin) -> None:
        """
        Register a plugin under the key returned by its ``get_data_key()``.

        If a plugin is already registered for that key, a warning is logged
        and the existing plugin is replaced.
        """
        data_type_key = plugin.get_data_key()

        if data_type_key in self._plugins:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type_key,
                plugin)

        self._plugins[data_type_key] = plugin

    def _get_plugin(self, data_type_key: str) -> MultiModalPlugin:
        """
        Return the plugin registered for ``data_type_key``.

        This is the single lookup point for the registry, so the
        "unknown data type" error is produced in exactly one place.

        Raises:
            NotImplementedError: If no plugin is registered for the key.
        """
        plugin = self._plugins.get(data_type_key)
        # Compare against None explicitly: a registered plugin must be found
        # even if the object happens to evaluate as falsy.
        if plugin is None:
            msg = f"Unknown multi-modal data type: {data_type_key}"
            raise NotImplementedError(msg)

        return plugin

    def register_image_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for image data to a model class.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper("image", mapper)

    def _process_input(self, key: str, value: object,
                       model_config: ModelConfig) -> MultiModalInputs:
        """
        Map a single modality's data using the plugin registered for ``key``.

        Raises:
            NotImplementedError: If no plugin is registered for ``key``.
        """
        return self._get_plugin(key).map_input(model_config, value)

    def register_input_mapper(
        self,
        data_type: str,
        mapper: Optional[MultiModalInputMapper] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        The call is forwarded to the plugin registered for ``data_type`` and
        that plugin's return value is passed back unchanged.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        Raises:
            NotImplementedError: If no plugin is registered for ``data_type``.
        """
        return self._get_plugin(data_type).register_input_mapper(mapper)

    def register_image_input(self,
                             mapper: Optional[MultiModalInputMapper] = None):
        """
        Register an input mapper for image pixel data to a model class.

        This is an alias of :meth:`register_image_input_mapper`.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_image_input_mapper(mapper)

    def map_input(self, model_config: ModelConfig,
                  data: MultiModalDataDict) -> MultiModalInputs:
        """
        Apply an input mapper to the data passed to the model.

        Each entry of ``data`` is mapped by the plugin registered for its
        key, and the resulting keyword arguments are merged into one
        :class:`MultiModalInputs`.

        See :meth:`MultiModalPlugin.map_input` for more details.

        Raises:
            NotImplementedError: If a key of ``data`` has no registered
                plugin.
            ValueError: If two mappers produce the same keyword argument.
        """
        merged_dict: Dict[str, torch.Tensor] = {}

        for data_key, data_value in data.items():
            input_dict = self._process_input(data_key, data_value,
                                             model_config)

            for input_key, input_tensor in input_dict.items():
                # Refuse to silently overwrite an argument produced by an
                # earlier mapper.
                if input_key in merged_dict:
                    raise ValueError(f"The input mappers (keys={set(data)}) "
                                     f"resulted in a conflicting keyword "
                                     f"argument to `forward()`: {input_key}")

                merged_dict[input_key] = input_tensor

        return MultiModalInputs(merged_dict)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        The returned callable is :meth:`map_input` with ``model_config``
        already bound, so it only needs the multi-modal data.
        """
        return functools.partial(self.map_input, model_config)