# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from collections import UserDict
5
from collections.abc import Mapping
6
from dataclasses import dataclass
7
8
from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional,
                    Protocol, Union)
9
10

from torch import nn
11
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
12
from typing_extensions import TypeVar, assert_never
13
14

from vllm.logger import init_logger
15
from vllm.transformers_utils.processor import cached_processor_from_config
16
from vllm.transformers_utils.tokenizer import AnyTokenizer
17
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
18
                        resolve_mm_processor_kwargs)
19

20
from .data import ProcessorInputs, SingletonInputs
21
from .parse import split_enc_dec_inputs
22
23

if TYPE_CHECKING:
24
    from vllm.config import ModelConfig
25
26
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
27
28
29
30
    from vllm.sequence import SequenceData

# Logger shared by everything in this module.
logger = init_logger(__name__)

# Generic type variables:
#   _T - arbitrary processor-like class instantiated by `init_processor`
#   _C - HuggingFace config class checked by `get_hf_config`
#   _P - HuggingFace processor class requested via `get_hf_processor`
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
34
35


36
37
38
39
40
41
42
43
44
45
@dataclass(frozen=True)
class InputContext:
    """
    Immutable wrapper around a model's configuration that input-handling
    code can query when it needs to adapt the inputs to the model.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) after verifying that it
        is an instance of ``typ``.

        Args:
            typ: The expected config class, or a tuple of accepted classes.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on
        the model config.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model, i.e. the
                model config carries no multimodal config.
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the HuggingFace processor
        (:class:`transformers.ProcessorMixin`) of the model, requesting
        it as ``typ``.

        The lookup (and the type check) is delegated to
        ``cached_processor_from_config``; ``kwargs`` are forwarded to it
        unchanged.

        Raises:
            TypeError: If the processor is not of the specified type
                (raised by the delegated helper - not verifiable from
                this module alone).
        """
        processor = cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )
        return processor

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate a HuggingFace-like processor class.

        The constructor receives the model config's
        ``mm_processor_kwargs`` (if any) overlaid with ``kwargs``; on a
        key collision the explicitly passed ``kwargs`` win.
        """
        config_kwargs = self.model_config.mm_processor_kwargs
        merged_kwargs = {
            **({} if config_kwargs is None else config_kwargs),
            **kwargs,
        }

        return typ(**merged_kwargs)
123

124

125
126
127
128
129
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An :class:`InputContext` that additionally carries the tokenizer, so
    HuggingFace processors can be fetched and invoked with it.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as :meth:`InputContext.get_hf_processor`, but always passes
        this context's tokenizer as the ``tokenizer`` keyword argument.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> BatchFeature:
        """
        Call :code:`hf_processor` on the prompt :code:`data`
        (text, image, audio...) with configurable options :code:`kwargs`.

        The options are resolved against the model config's
        ``mm_processor_kwargs`` via ``resolve_mm_processor_kwargs`` and
        the processor is always asked for PyTorch tensors
        (``return_tensors="pt"``).

        Raises:
            RuntimeError: If the processor call raises; the original
                exception is attached as the cause.
        """
        assert callable(hf_processor)

        config_kwargs = self.model_config.mm_processor_kwargs
        merged_kwargs = resolve_mm_processor_kwargs(
            {} if config_kwargs is None else config_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
        except Exception as exc:
            # Re-raise with enough context to identify the failing call.
            raise RuntimeError(
                f"Failed to apply {type(hf_processor).__name__} "
                f"on data={data} with kwargs={merged_kwargs}") from exc

174

175
# Type variable for `nn.Module` subclasses; it annotates the class decorators
# returned by the registry's `register_*` methods, which hand the decorated
# class back unchanged.
N = TypeVar("N", bound=type[nn.Module])
176
177


178
179
180
181
182
183
184
185
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    Only ``seq_data`` is required; both multi-modal fields default to
    ``None``.
    """

    # Dummy token sequence; its `prompt_token_ids` length is checked against
    # the requested sequence length during profiling.
    seq_data: "SequenceData"

    # Optional dummy multi-modal inputs, keyed by modality.
    multi_modal_data: Optional["MultiModalDataDict"] = None

    # Optional placeholder information accompanying `multi_modal_data`.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


186
187
188
189
190
191
192
class DummyDataFactory(Protocol):
    """Structural type of callables that build :class:`DummyData`."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build the dummy data that is fed into the model.

        Args:
            ctx: Context giving access to the model configuration.
            seq_len: Requested length of the dummy sequence.
            mm_counts: Number of items per multi-modal key.
            **mm_processor_kwargs: Overrides provided at initialization
                time to values in the config whose values may affect the
                number of tokens per instance.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...


208
class _MultiModalCounts(UserDict[str, int]):
    """
    Dictionary wrapper around `mm_counts` whose lookups fail with a
    descriptive error (listing the available keys) when the requested
    multi-modal plugin key does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            count = super().__getitem__(key)
        except KeyError as exc:
            available = set(self.keys())
            raise KeyError(
                f"There is no multi-modal plugin with the key: {key}. "
                f"Available keys: {available}") from exc

        return count

222

223
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""
Preprocess the inputs to the model: a callable taking the input context and
the inputs, and returning the processed inputs.
"""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        """
        Create the three empty per-model-class registries: decoder dummy
        data factories, encoder dummy data factories and input processors.
        """
        self._dummy_factories_by_model_type = (
            ClassRegistry[nn.Module, DummyDataFactory]())
        self._dummy_encoder_factories_by_model_type = (
            ClassRegistry[nn.Module, DummyDataFactory]())
        self._input_processors_by_model_type = (
            ClassRegistry[nn.Module, InputProcessor]())
240
241
242
243
244

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        Fallback dummy data factory: text-only data standing in for the
        longest possible text that can be inputted to the model.

        ``ctx`` and ``mm_counts`` are accepted for signature compatibility
        but are not used; no multi-modal data is produced.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        # Presumably `seq_len` repetitions of token id 0 - confirm against
        # `SequenceData.from_prompt_token_counts`.
        seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
        return DummyData(seq_data)
258
259
260
261
262
263
264
265
266
267
268

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Return a class decorator that registers ``factory`` as the dummy
        data factory of the decorated model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.

        Registering a second factory for the same class logs a warning and
        replaces the previous one.
        """

        def wrapper(model_cls: N) -> N:
            registry = self._dummy_factories_by_model_type

            already_registered = registry.contains(model_cls, strict=True)
            if already_registered:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            registry[model_cls] = factory

            # Hand the class back untouched so this works as a decorator.
            return model_cls

        return wrapper

282
    def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
        """
        Look up the dummy data factory registered for ``model_cls``,
        falling back to :meth:`_default_dummy_data_factory`.
        """
        registry = self._dummy_factories_by_model_type
        return registry.get(model_cls, self._default_dummy_data_factory)

286
287
288
289
290
291
292
293
    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Return a class decorator that registers ``factory`` as the dummy
        encoder data factory of the decorated model class.

        This is similar to :meth:`~register_dummy_data`, but for encoder input.
        """

        def wrapper(model_cls: N) -> N:
            registry = self._dummy_encoder_factories_by_model_type

            already_registered = registry.contains(model_cls, strict=True)
            if already_registered:
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            registry[model_cls] = factory

            # Hand the class back untouched so this works as a decorator.
            return model_cls

        return wrapper

307
    def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
        """
        Look up the dummy encoder data factory registered for
        ``model_cls``, falling back to :meth:`_default_dummy_data_factory`.
        """
        registry = self._dummy_encoder_factories_by_model_type
        return registry.get(model_cls, self._default_dummy_data_factory)
310

311
312
313
314
315
    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Two paths exist:

        * If ``mm_registry`` has a processor for the model, the dummy data
          is obtained from a ``MultiModalProfiler`` built on that processor.
        * Otherwise the dummy data factory registered for the model class
          (or the default one) is called with the per-prompt multi-modal
          limits and the applicable ``mm_processor_kwargs`` overrides.

        The result is then validated against ``seq_len`` and, for
        multi-modal data that is not a ``MultiModalKwargs``, against the
        per-prompt multi-modal limits.

        Args:
            model_config: Configuration identifying the model.
            seq_len: Minimum number of dummy prompt tokens expected.
            mm_registry: Registry used to create the processor and to query
                the per-prompt multi-modal limits.
            is_encoder_data: Whether to build encoder (rather than decoder)
                dummy data.

        Raises:
            AssertionError: If fewer than ``seq_len`` decoder tokens were
                produced (for encoder data only a warning is logged), or if
                fewer multi-modal items than the per-prompt limit were
                produced for some key.

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.multimodal import MultiModalKwargs
        from vllm.multimodal.profiling import MultiModalProfiler
        from vllm.sequence import SequenceData

        # Only the legacy factory path needs the limits to build the data.
        # Keep the name bound on every path so the validation below cannot
        # hit an unbound local; it fetches the limits lazily when required.
        mm_counts: Optional[Mapping[str, int]] = None

        if mm_registry.has_processor(model_config):
            processor = mm_registry.create_processor(model_config,
                                                     disable_cache=True)
            profiler = MultiModalProfiler(processor)

            dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
                             if is_encoder_data else
                             profiler.get_decoder_dummy_data(seq_len))
            _seq_data = SequenceData.from_seqs(
                dummy_data_v1.prompt_token_ids)  # type: ignore[attr-defined]

            # The profiler result may lack the multi-modal attributes,
            # hence the defensive `getattr` with a `None` default.
            dummy_data = DummyData(
                seq_data=_seq_data,
                multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
                                         None),
                multi_modal_placeholders=getattr(dummy_data_v1,
                                                 "multi_modal_placeholders",
                                                 None),
            )
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
                dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
            else:
                dummy_factory = self._get_dummy_data_factory(model_cls)
            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

            # Forward only the overrides that the factory can accept.
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                dummy_factory,
                overrides=model_config.mm_processor_kwargs,
                requires_kw_only=False,
                allow_var_kwargs=True,
            )

            dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                       _MultiModalCounts(mm_counts),
                                       **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        prompt_token_ids = dummy_data.seq_data.prompt_token_ids
        if len(prompt_token_ids) < seq_len:
            if is_encoder_data:
                logger.warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {len(prompt_token_ids)} tokens "
                    "instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {len(prompt_token_ids)} tokens instead.")

        mm_data = dummy_data.multi_modal_data
        if mm_data is not None and not isinstance(mm_data, MultiModalKwargs):
            if mm_counts is None:
                # Processor path: the limits were not needed to build the
                # data, so fetch them now for the validation.
                mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

            for k, v in mm_data.items():
                num_items = len(v) if isinstance(v, list) else 1
                num_expected = mm_counts[k]
                # Explicit raise (not `assert`) so the check survives `-O`,
                # matching the token-count check above.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy '{k}' "
                        f"instances for profiling, but found {num_items} "
                        "instances instead.")

        return dummy_data
392

393
394
395
    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
        **kwargs: object,
    ) -> ProcessorInputs:
        """
        Fallback input processor: a no-op that hands ``inputs`` back
        unchanged, ignoring ``ctx`` and any keyword arguments.
        """
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Return a class decorator that registers ``processor`` as the input
        processor of the decorated model class.

        The provided function is invoked on each input to the model. This
        happens before
        :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.

        Registering a second processor for the same class logs a warning
        and replaces the previous one.
        """

        def wrapper(model_cls: N) -> N:
            registry = self._input_processors_by_model_type

            already_registered = registry.contains(model_cls, strict=True)
            if already_registered:
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            registry[model_cls] = processor

            # Hand the class back untouched so this works as a decorator.
            return model_cls

        return wrapper

425
    def _get_model_input_processor(self, model_cls: type[nn.Module]):
        """
        Look up the input processor registered for ``model_cls``, falling
        back to :meth:`_default_input_processor`.
        """
        registry = self._input_processors_by_model_type
        return registry.get(model_cls, self._default_input_processor)

429
430
431
    def _ensure_mm_kwargs(
        self,
        inputs: SingletonInputs,
        mm_processor_kwargs: dict[str, Any],
    ):
        """
        Make sure processed ``inputs`` carry their multi-modal kwargs.

        * ``"multimodal"`` inputs must already contain ``"mm_kwargs"``.
        * ``"token"`` inputs get ``"mm_processor_kwargs"`` filled in
          (in place) when the key is missing.
        * Any other input type is rejected via ``assert_never``.
        """
        if inputs["type"] == "multimodal":
            # Be more strict in V2
            assert "mm_kwargs" in inputs
        elif inputs["type"] == "token":
            # In case the input processor for that model fails to set it
            if "mm_processor_kwargs" not in inputs:
                inputs["mm_processor_kwargs"] = mm_processor_kwargs
        else:
            assert_never(inputs["type"])  # type: ignore[arg-type]
443

444
    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Run the input processor registered for the model on ``inputs``.

        The model is identified by ``model_config``. After processing, the
        encoder and decoder parts of the result (whichever are present) are
        checked with :meth:`_ensure_mm_kwargs`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        init_kwargs = model_config.mm_processor_kwargs
        inference_kwargs = inputs.get("mm_processor_kwargs",
                                      {})  # type: ignore
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            init_kwargs,
            inference_kwargs,
            processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        ctx = InputContext(model_config)
        processed_inputs = processor(ctx, inputs, **mm_processor_kwargs)

        encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
        for part in (encoder_inputs, decoder_inputs):
            if part is not None:
                self._ensure_mm_kwargs(part, mm_processor_kwargs)

        return processed_inputs
481
482
483

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.

        The returned callable is :meth:`process_input` with ``model_config``
        already bound, so callers only need to supply the inputs.
        """
        return functools.partial(self.process_input, model_config)