registry.py 15.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from collections import UserDict
5
from dataclasses import dataclass
6
7
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
                    Optional, Protocol, Union)
8
9

from torch import nn
10
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
11
from typing_extensions import TypeVar, assert_never
12
13

from vllm.logger import init_logger
14
15
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
16
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
17
                        resolve_mm_processor_kwargs)
18

19
20
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
21
22

if TYPE_CHECKING:
23
    from vllm.config import ModelConfig
24
25
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
26
27
28
29
    from vllm.sequence import SequenceData

# Module-level logger shared by everything in this file.
logger = init_logger(__name__)

# Type variables used to type-check the object returned by the
# ``get_hf_config`` / ``get_hf_processor`` helpers against the requested class.
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
32
33
34
35
36
37
38
39
40
41
42
43


@dataclass(frozen=True)
class InputContext:
    """
    Read-only bundle of model information that input-processing code may
    consult when transforming the inputs to a model.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
        /,
    ) -> C:
        """
        Return the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model after checking
        that it is an instance of ``typ``.

        Args:
            typ: A class, or a tuple of classes, that the configuration must
                be an instance of.

        Raises:
            TypeError: If the configuration is not an instance of ``typ``.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration that is stored
        on the model config.
        """
        model_config = self.model_config
        return model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multi-modal configuration of the model.

        Raises:
            RuntimeError: If ``model_config.multimodal_config`` is ``None``,
                i.e. the model is not a multi-modal model.
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> P:
        """
        Return the HuggingFace processor
        (:class:`transformers.ProcessorMixin`) of the model after checking
        that it is an instance of ``typ``.

        The keyword arguments are layered on top of
        ``model_config.mm_processor_kwargs``; on a key clash the value passed
        here wins.

        Raises:
            TypeError: If the processor is not an instance of ``typ``.
        """
        processor_kwargs = dict(self.model_config.mm_processor_kwargs or {})
        processor_kwargs.update(kwargs)

        # ``processor_cls`` is only forwarded when a single class was
        # requested; a tuple of acceptable classes is only checked after the
        # processor has been constructed.
        if isinstance(typ, type):
            processor_kwargs["processor_cls"] = typ

        processor = cached_get_processor(
            self.model_config.model,
            trust_remote_code=self.model_config.trust_remote_code,
            **processor_kwargs,
        )

        if isinstance(processor, typ):
            return processor

        raise TypeError("Invalid type of HuggingFace processor. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(processor)}")
118

119

120
121
122
123
124
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An :class:`InputContext` that additionally carries the tokenizer, so that
    HuggingFace processors can be built with it and then invoked.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> P:
        """
        Same as :meth:`InputContext.get_hf_processor`, but additionally passes
        ``tokenizer=self.tokenizer`` through to the processor factory.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Optional[Mapping[str, object]] = None,
    ) -> BatchFeature:
        """
        Call :code:`hf_processor` on the prompt :code:`data`
        (text, image, audio...) with configurable options :code:`kwargs`.

        Args:
            hf_processor: The (callable) HuggingFace processor to apply.
            data: Keyword arguments holding the prompt data, splatted into
                the processor call.
            kwargs: Inference-time processor options. They take priority over
                ``model_config.mm_processor_kwargs``; both are filtered by
                :func:`resolve_mm_processor_kwargs` against the processor's
                signature. ``None`` (the default) means no extra options.

        Returns:
            The processor output, requested as PyTorch tensors.

        Raises:
            RuntimeError: If the processor call raises; the original exception
                is chained as the cause.
        """
        assert callable(hf_processor)

        # A ``None`` sentinel avoids sharing one mutable ``{}`` default object
        # across every call of this method.
        if kwargs is None:
            kwargs = {}

        base_kwargs = self.model_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise RuntimeError(msg) from exc

169

170
# Type of a model class passed through the ``register_*`` decorators below;
# each decorator returns the same class, so its precise type is preserved.
N = TypeVar("N", bound=type[nn.Module])
171
172


173
174
175
176
177
178
179
180
class DummyData(NamedTuple):
    """Placeholder inputs that are fed to the model during profiling."""

    # Token sequence to profile with.
    seq_data: "SequenceData"
    # Optional multi-modal inputs that accompany ``seq_data``.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Optional placeholder ranges for the multi-modal inputs.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


181
182
183
184
185
186
187
class DummyDataFactory(Protocol):
    """Callable protocol for model-specific profiling dummy data builders."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build dummy data to be fed to the model.

        Args:
            ctx: Model information available to the factory.
            seq_len: The target sequence length.
            mm_counts: Number of multi-modal items per modality key.
            **mm_processor_kwargs: Initialization-time overrides of config
                values that may affect the number of tokens per instance.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...


203
class _MultiModalCounts(UserDict[str, int]):
204
205
206
207
208
209
210
211
212
213
214
215
216
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc

217

218
# Per-model input processor: receives the model's InputContext and the inputs
# and returns the (possibly modified) inputs. ``InputRegistry.process_input``
# also passes the resolved multi-modal processor kwargs as keyword arguments,
# which this alias does not express.
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        # Decoder-side profiling dummy data factories, keyed by model class.
        self._dummy_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        # Encoder-side profiling dummy data factories, keyed by model class.
        self._dummy_encoder_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        # Input processors applied by ``process_input``, keyed by model class.
        self._input_processors_by_model_type = \
            ClassRegistry[nn.Module, InputProcessor]()

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        It returns text-only dummy data built from the prompt token counts
        ``(0, seq_len)`` (presumably ``seq_len`` copies of token id 0 —
        see ``SequenceData.from_prompt_token_counts``).

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.

        Returns a class decorator that returns the model class unchanged.
        Registering again for the same class logs a warning and replaces the
        previous factory.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_factories_by_model_type.contains(model_cls,
                                                            strict=True):
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
        """Return the decoder dummy data factory for ``model_cls``, or the
        default text-only factory when none is registered."""
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Register a dummy encoder data factory to a model class

        This is similar to :meth:`~register_dummy_data`, but for encoder input.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_encoder_factories_by_model_type.contains(
                    model_cls, strict=True):
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_encoder_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
        """Return the encoder dummy data factory for ``model_cls``, or the
        default text-only factory when none is registered."""
        return self._dummy_encoder_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.

        Raises:
            AssertionError: If decoder dummy data has fewer than ``seq_len``
                tokens (encoder dummy data only logs a warning), or if a
                registered dummy data factory returns fewer multi-modal items
                for a modality than the per-prompt limit.
            KeyError: If such a factory returns a modality that has no
                per-prompt limit.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.multimodal import MultiModalKwargs
        from vllm.multimodal.profiling import MultiModalProfiler
        from vllm.multimodal.utils import cached_get_tokenizer

        # Per-prompt multi-modal item limits. They are only looked up on the
        # registered-factory path; on the processor path this stays None and
        # the per-modality item check at the end is skipped.
        mm_counts: Optional[_MultiModalCounts] = None

        if mm_registry.has_processor(model_config):
            tokenizer = cached_get_tokenizer(
                model_config.tokenizer,
                trust_remote_code=model_config.trust_remote_code,
            )
            processor = mm_registry.create_processor(model_config, tokenizer)
            profiler = MultiModalProfiler(processor)
            # NOTE(review): ``is_encoder_data`` is not consulted on this path.
            dummy_data = profiler.get_dummy_data(seq_len)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
                dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
            else:
                dummy_factory = self._get_dummy_data_factory(model_cls)
            mm_counts = _MultiModalCounts(
                mm_registry.get_mm_limits_per_prompt(model_config))
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                dummy_factory, overrides=model_config.mm_processor_kwargs)

            dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                       mm_counts, **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = len(dummy_data.seq_data.prompt_token_ids)
        if num_tokens < seq_len:
            if is_encoder_data:
                logger.warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {num_tokens} tokens instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {num_tokens} tokens instead.")

        mm_data = dummy_data.multi_modal_data
        if (mm_counts is not None and mm_data is not None
                and not isinstance(mm_data, MultiModalKwargs)):
            for modality, items in mm_data.items():
                # A list holds several items of this modality; any other
                # value counts as a single item.
                num_items = len(items) if isinstance(items, list) else 1
                num_expected = mm_counts[modality]
                # Explicit raise (not ``assert``) so the check survives -O.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy '{modality}' "
                        f"instances for profiling, but found {num_items} "
                        "instances instead.")

        return dummy_data

    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before
        :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.

        Returns a class decorator that returns the model class unchanged.
        Registering again for the same class logs a warning and replaces the
        previous processor.
        """

        def wrapper(model_cls: N) -> N:
            if self._input_processors_by_model_type.contains(model_cls,
                                                             strict=True):
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def _get_model_input_processor(self, model_cls: type[nn.Module]):
        """Return the input processor for ``model_cls``, or the no-op default
        when none is registered."""
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def _ensure_mm_kwargs(
        self,
        inputs: SingletonInputs,
        mm_processor_kwargs: dict[str, Any],
    ):
        """
        Make sure processed single-sequence inputs carry multi-modal kwargs.

        Mutates ``inputs`` in place: ``"token"`` inputs get
        ``mm_processor_kwargs`` if the key is missing; ``"multimodal"`` inputs
        must already contain ``"mm_kwargs"``.
        """
        if inputs["type"] == "token":
            # In case the input processor for that model fails to set it
            if "mm_processor_kwargs" not in inputs:
                inputs["mm_processor_kwargs"] = mm_processor_kwargs
        elif inputs["type"] == "multimodal":
            # Be more strict in V2
            assert "mm_kwargs" in inputs
        else:
            assert_never(inputs["type"])  # type: ignore[arg-type]

    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs", {}),  # type: ignore
            processor,
        )

        processed_inputs = processor(
            InputContext(model_config),
            inputs,
            **mm_processor_kwargs,
        )

        if is_encoder_decoder_inputs(processed_inputs):
            self._ensure_mm_kwargs(processed_inputs["encoder"],
                                   mm_processor_kwargs)
            self._ensure_mm_kwargs(processed_inputs["decoder"],
                                   mm_processor_kwargs)
        else:
            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)

        return processed_inputs

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.

        Returns a partial of :meth:`process_input` with ``model_config``
        bound, so it only needs the inputs.
        """
        return functools.partial(self.process_input, model_config)