# registry.py
1
import functools
2
from typing import Any, Optional, Sequence, Type, TypeVar
3

4
5
6
from torch import nn

from vllm.config import ModelConfig
7
8
from vllm.logger import init_logger

9
from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
10
11
12
13
14
15
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
                    ImagePixelPlugin)

# Logger for this module, obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

# Type variable ranging over concrete subclasses of ``MultiModalData``;
# used to tie a data type to the input mapper that handles it.
D = TypeVar("D", bound=MultiModalData)
# Type variable ranging over model classes (types of ``nn.Module``).
N = TypeVar("N", bound=Type[nn.Module])
17
18
19
20


class MultiModalRegistry:
    """
    A registry to dispatch data processing
    according to its modality and the target model.

    Plugins are indexed by the data type returned from their
    ``get_data_type()`` method. Every public method looks up the plugin
    responsible for the relevant data type and delegates to it.
    """

    DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())

    def __init__(
        self,
        *,
        plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS,
    ) -> None:
        """
        Build the registry from an initial collection of plugins.

        :param plugins: Plugins to register up front; defaults to
            :attr:`DEFAULT_PLUGINS`. If several plugins report the same
            data type, the last one in the sequence wins.
        """
        # NOTE(review): ``DEFAULT_PLUGINS`` is evaluated once, when the
        # class body runs, so registries created without ``plugins``
        # share the same plugin instances.
        self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}

    def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
        """
        Register a plugin under the data type it reports.

        If a plugin is already registered for that data type, a warning
        is logged and the existing plugin is replaced.
        """
        data_type = plugin.get_data_type()

        if data_type in self._plugins_by_data_type:
            logger.warning(
                "A plugin is already registered for data type %s, "
                "and will be overwritten by the new plugin %s.", data_type,
                plugin)

        self._plugins_by_data_type[data_type] = plugin

    def _get_plugin_for_data_type(
        self,
        data_type: Type[MultiModalData],
    ) -> MultiModalPlugin[Any]:
        """
        Find the plugin registered for ``data_type`` or one of its bases.

        The method resolution order of ``data_type`` is walked from the
        most specific class outwards, so a subclass of a registered data
        type resolves to the plugin of its nearest registered ancestor.

        :raises NotImplementedError: If no class in the MRO has a
            registered plugin.
        """
        for typ in data_type.mro():
            plugin = self._plugins_by_data_type.get(typ)
            if plugin is not None:
                return plugin

        msg = f"Unknown multi-modal data type: {data_type}"
        raise NotImplementedError(msg)

    def register_input_mapper(
        self,
        data_type: Type[D],
        mapper: Optional[MultiModalInputMapper[D]] = None,
    ):
        """
        Register an input mapper for a specific modality to a model class.

        The call is forwarded to the plugin registered for ``data_type``
        and that plugin's return value is passed through unchanged.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.

        :raises NotImplementedError: If no plugin handles ``data_type``.
        """
        plugin = self._get_plugin_for_data_type(data_type)
        return plugin.register_input_mapper(mapper)

    def register_image_pixel_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImagePixelData]] = None,
    ):
        """
        Register an input mapper for image pixel data to a model class.

        Shorthand for :meth:`register_input_mapper` with
        ``ImagePixelData`` as the data type.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImagePixelData, mapper)

    def register_image_feature_input_mapper(
        self,
        mapper: Optional[MultiModalInputMapper[ImageFeatureData]] = None,
    ):
        """
        Register an input mapper for image feature data to a model class.

        Shorthand for :meth:`register_input_mapper` with
        ``ImageFeatureData`` as the data type.

        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
        """
        return self.register_input_mapper(ImageFeatureData, mapper)

    def map_input(self, model_config: ModelConfig, data: MultiModalData):
        """
        Apply an input mapper to a :class:`~MultiModalData` instance passed
        to the model.

        The plugin is selected from the runtime type of ``data``, and its
        ``map_input`` result is returned unchanged.

        See :meth:`MultiModalPlugin.map_input` for more details.

        :raises NotImplementedError: If no plugin handles ``type(data)``.
        """
        plugin = self._get_plugin_for_data_type(type(data))
        return plugin.map_input(model_config, data)

    def create_input_mapper(self, model_config: ModelConfig):
        """
        Create an input mapper (see :meth:`map_input`) for a specific model.

        The returned callable has ``model_config`` already bound, so it
        only needs to be given the multi-modal data.
        """
        return functools.partial(self.map_input, model_config)