registry.py 11 KB
Newer Older
1
import functools
2
from collections import UserDict
3
from dataclasses import dataclass
4
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
5
                    Optional, Protocol, Type, cast)
6
7
8

from torch import nn
from transformers import PretrainedConfig
9
from typing_extensions import TypeVar
10
11

from vllm.logger import init_logger
12
13
from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
                        resolve_mm_processor_kwargs)
14

15
from .data import ProcessorInputs
16
17

if TYPE_CHECKING:
18
    from vllm.config import ModelConfig
19
20
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
21
22
23
24
    from vllm.sequence import SequenceData

# Module-level logger obtained from vLLM's `init_logger` for this module;
# used below to warn when a registration overwrites an existing entry.
logger = init_logger(__name__)

25
# Type variable for HuggingFace config classes. It is bound to
# `PretrainedConfig` and also defaults to it when left unspecified.
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
26
27
28
29
30
31
32
33
34
35
36
37


@dataclass(frozen=True)
class InputContext:
    """
    Immutable bundle of model information that input-handling code
    may consult in order to adjust the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
        """
        Return the model's HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) after verifying that it
        is an instance of ``hf_config_type``.

        Args:
            hf_config_type: The configuration class the caller expects.

        Returns:
            ``model_config.hf_config``, unchanged.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        config = self.model_config.hf_config

        # Happy path first: the stored config already has the expected type.
        if isinstance(config, hf_config_type):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {hf_config_type}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> Dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on
        the model configuration.
        """
        image_processor_config = self.model_config.hf_image_processor_config
        return image_processor_config

63
64
65
66

# Type variable for model classes (types of `nn.Module`). The registration
# decorators in this module take and return `N`, so decorating a model class
# preserves its static type.
N = TypeVar("N", bound=Type[nn.Module])


67
68
69
70
71
72
73
74
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    Only ``seq_data`` is required; both multi-modal fields default to
    ``None``.
    """

    # Sequence data for the dummy prompt.
    seq_data: "SequenceData"

    # Optional dummy multi-modal inputs, keyed by multi-modal data key.
    multi_modal_data: Optional["MultiModalDataDict"] = None

    # Optional placeholder information accompanying `multi_modal_data`
    # (NOTE(review): exact semantics are defined by
    # `MultiModalPlaceholderDict`, which is not visible here).
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


75
76
77
78
79
80
81
class DummyDataFactory(Protocol):
    """Structural type of a callable that builds :class:`DummyData`."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Create dummy data to be inputted into the model.

        Args:
            ctx: Context carrying the model configuration.
            seq_len: The sequence length the dummy data is created for.
            mm_counts: For each multi-modal data key, the number of
                instances the dummy data is expected to provide.
            **mm_processor_kwargs: Keyword-only overrides (see the note
                below).

        Note:
            :data:`InputProcessor` is not applied to the dummy data.

            The :code:`mm_processor_kwargs` are overrides provided at
            initialization time to values in the config whose values
            may affect the number of tokens per instance.
        """
        ...


class _MultiModalCounts(UserDict):
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc

111

112
# Signature of a model-specific input processor: it receives the
# InputContext and the ProcessorInputs and returns ProcessorInputs.
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.

    Three independent tables are kept, each keyed by model class:
    dummy (decoder) data factories, dummy encoder data factories, and
    input processors. Lookups for an unregistered model class fall back
    to the corresponding default defined on this class.
    """

    def __init__(self) -> None:
        """Create a registry with no model-specific entries."""
        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
                                                  DummyDataFactory] = {}
        self._dummy_encoder_factories_by_model_type: Dict[
            Type[nn.Module], DummyDataFactory] = {}
        self._input_processors_by_model_type: Dict[Type[nn.Module],
                                                   InputProcessor] = {}

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        ``ctx`` and ``mm_counts`` are accepted to satisfy the
        :class:`DummyDataFactory` signature but are not used.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.

        Returns:
            A class decorator that records ``factory`` for the decorated
            model class and returns the class unchanged. An existing
            registration is overwritten, with a warning.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
        """
        Return the dummy data factory registered for ``model_cls``, or
        the default factory if none is registered.
        """
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Register a dummy encoder data factory to a model class

        This is similar to :meth:`~register_dummy_data`, but for encoder input.

        Returns:
            A class decorator that records ``factory`` for the decorated
            model class and returns the class unchanged. An existing
            registration is overwritten, with a warning.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_encoder_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_encoder_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
        """
        Return the dummy encoder data factory registered for
        ``model_cls``, or the default factory if none is registered.
        """
        return self._dummy_encoder_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Args:
            model_config: Configuration identifying the model.
            seq_len: Minimum number of prompt tokens the dummy data is
                expected to contain.
            mm_registry: Registry queried for the per-prompt multi-modal
                limits of the model.
            is_encoder_data: If true, use the dummy *encoder* data
                factory; a token shortfall then only produces a warning.

        Returns:
            The dummy data produced by the selected factory.

        Raises:
            AssertionError: If decoder dummy data has fewer than
                ``seq_len`` tokens, or if any multi-modal entry provides
                fewer instances than the limit for its key.

        See also:
            :ref:`enabling_multimodal_inputs`

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        if is_encoder_data:
            dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
        else:
            dummy_factory = self._get_dummy_data_factory(model_cls)
        mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

        # Only forward the overrides that the selected factory can accept.
        mm_processor_kwargs = get_allowed_kwarg_only_overrides(
            dummy_factory, overrides=model_config.mm_processor_kwargs)

        dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                   _MultiModalCounts(mm_counts),
                                   **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = len(dummy_data.seq_data.prompt_token_ids)
        if num_tokens < seq_len:
            if is_encoder_data:
                print_warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {num_tokens} tokens instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {num_tokens} tokens instead.")

        if dummy_data.multi_modal_data is not None:
            for k, v in dummy_data.multi_modal_data.items():
                # A non-list value counts as a single instance.
                num_items = len(v) if isinstance(v, list) else 1
                num_expected = mm_counts[k]
                # Raise explicitly (rather than `assert`) so the check is
                # not stripped when Python runs with optimizations enabled.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy '{k}' "
                        f"instances for profiling, but found {num_items} "
                        "instances instead.")

        return dummy_data

    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.

        Returns:
            A class decorator that records ``processor`` for the decorated
            model class and returns the class unchanged. An existing
            registration is overwritten, with a warning.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors_by_model_type:
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def _get_model_input_processor(self, model_cls: Type[nn.Module]):
        """
        Return the input processor registered for ``model_cls``, or the
        default (no-op) processor if none is registered.
        """
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.

        See also:
            :ref:`input_processing_pipeline`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
            processor,
        )

        return processor(InputContext(model_config), inputs,
                         **mm_processor_kwargs)

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.

        Returns:
            A callable equivalent to :meth:`process_input` with
            ``model_config`` already bound.
        """
        return functools.partial(self.process_input, model_config)