registry.py 16.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from collections import UserDict
5
from collections.abc import Mapping
6
from dataclasses import dataclass
7
8
from typing import (TYPE_CHECKING, Any, Callable, NamedTuple, Optional,
                    Protocol, Union)
9
10

from torch import nn
11
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
12
from typing_extensions import TypeVar, assert_never
13
14

from vllm.logger import init_logger
15
16
17
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
18
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
19
                        resolve_mm_processor_kwargs)
20

21
22
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
23
24

if TYPE_CHECKING:
25
    from vllm.config import ModelConfig
26
27
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
28
29
30
31
    from vllm.sequence import SequenceData

# Module-level logger; used by InputRegistry for re-registration warnings
# and profiling diagnostics.
logger = init_logger(__name__)

32
33
34
# Return type of ``InputContext.init_processor`` (whatever class is built).
_T = TypeVar("_T")
# HuggingFace config type checked/returned by ``get_hf_config``; the
# ``default=`` (PEP 696, via typing_extensions) applies when ``typ`` is omitted.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
# HuggingFace processor type requested from/returned by ``get_hf_processor``.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
35
36


37
38
39
40
41
42
43
44
45
46
@dataclass(frozen=True)
class InputContext:
    """
    Immutable handle on a model's configuration, handed to input-processing
    code that needs model-specific information (HF config, processors,
    multi-modal settings).
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) after verifying that it is
        an instance of ``typ`` (a class or a tuple of classes).

        Raises:
            TypeError: If the configuration is not an instance of ``typ``.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on the
        model config.
        """
        model_config = self.model_config
        return model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multi-modal configuration of the model.

        Raises:
            RuntimeError: If the model has no multi-modal configuration.
        """
        multimodal_config = self.model_config.multimodal_config
        if multimodal_config is not None:
            return multimodal_config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the (cached) HuggingFace processor
        (:class:`transformers.ProcessorMixin`) for the model, requesting
        processor class ``typ`` and forwarding ``kwargs`` to its loader.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        model_config = self.model_config
        processor = cached_processor_from_config(
            model_config,
            processor_cls=typ,
            **kwargs,
        )
        return processor

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate ``typ`` (a HuggingFace-like processor class) using the
        model's ``mm_processor_kwargs`` as defaults, with explicit ``kwargs``
        taking precedence on conflicting keys.
        """
        config_kwargs = self.model_config.mm_processor_kwargs
        defaults = {} if config_kwargs is None else config_kwargs

        # Later unpacking wins, so call-site kwargs override config defaults.
        return typ(**{**defaults, **kwargs})
124

125

126
127
128
129
130
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    :class:`InputContext` that additionally carries the tokenizer, so that
    HuggingFace processors can be created with it and invoked on prompts.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as :meth:`InputContext.get_hf_processor`, but always passes this
        context's :attr:`tokenizer` to the processor loader.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Optional[Mapping[str, object]] = None,
    ) -> BatchFeature:
        """
        Call :code:`hf_processor` on the prompt :code:`data`
        (text, image, audio...) with configurable options :code:`kwargs`.

        The options are combined with the model's ``mm_processor_kwargs``
        via ``resolve_mm_processor_kwargs`` before the call, and the
        processor is always asked for PyTorch tensors
        (``return_tensors="pt"``).

        Args:
            hf_processor: The callable HuggingFace processor to invoke.
            data: Inputs forwarded to the processor as keyword arguments.
            kwargs: Per-call processor options. ``None`` (the default)
                means no per-call options.

        Raises:
            RuntimeError: If the processor call fails; the original
                exception is chained as the cause.
        """
        assert callable(hf_processor)

        # Default to a fresh mapping per call instead of a shared mutable
        # default argument.
        if kwargs is None:
            kwargs = {}

        base_kwargs = self.model_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise RuntimeError(msg) from exc

175

176
# Model-class type variable: lets the registration decorators on
# InputRegistry return exactly the class object they were given.
N = TypeVar("N", bound=type[nn.Module])
177
178


179
180
181
182
183
184
185
186
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    Produced by a :class:`DummyDataFactory` (or the multi-modal profiler)
    and validated by ``InputRegistry.dummy_data_for_profiling``.
    """

    # Token sequence; its ``prompt_token_ids`` length is checked against the
    # requested profiling ``seq_len``.
    seq_data: "SequenceData"
    # Optional multi-modal inputs. Either already-processed kwargs or a
    # mapping of modality -> item (or list of items) whose counts are
    # checked against the per-prompt limits.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Optional multi-modal placeholder metadata, if the producer supplies it.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


187
188
189
190
191
192
193
class DummyDataFactory(Protocol):
    """
    Structural type for callables that build :class:`DummyData` used during
    memory profiling.
    """

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build dummy data to feed into the model while profiling.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.

            ``mm_processor_kwargs`` holds initialization-time overrides of
            config values that can change how many tokens each multi-modal
            instance takes up.
        """
        ...


209
class _MultiModalCounts(UserDict[str, int]):
210
211
212
213
214
215
216
217
218
219
220
221
222
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc

223

224
# Signature of a per-model input processor: given the model context and the
# raw inputs, return the processed inputs (registered via
# ``InputRegistry.register_input_processor``).
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        # Per-model-class registries; lookups fall back to the default
        # factory/processor when nothing is registered for a class.
        self._dummy_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._dummy_encoder_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._input_processors_by_model_type = \
            ClassRegistry[nn.Module, InputProcessor]()

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_factories_by_model_type.contains(model_cls,
                                                            strict=True):
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
        """Return the decoder dummy data factory for ``model_cls``."""
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Register a dummy encoder data factory to a model class

        This is similar to :meth:`~register_dummy_data`, but for encoder input.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_encoder_factories_by_model_type.contains(
                    model_cls, strict=True):
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_encoder_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
        """Return the encoder dummy data factory for ``model_cls``."""
        return self._dummy_encoder_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        If ``mm_registry`` has a multi-modal processor for the model, the
        dummy data comes from :class:`MultiModalProfiler`; otherwise the
        dummy data factory registered for the model class (or the default
        text-only factory) is used.

        Raises:
            AssertionError: If decoder dummy data has fewer than ``seq_len``
                tokens, or if fewer dummy multi-modal items are produced for
                a modality than its per-prompt limit.

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.multimodal import MultiModalKwargs
        from vllm.multimodal.profiling import MultiModalProfiler

        # Per-prompt item limits for each modality. Only the legacy factory
        # path needs them up front; the item-count check below fetches them
        # lazily otherwise (previously an unbound name -> NameError).
        mm_counts: Optional[Mapping[str, int]] = None

        if mm_registry.has_processor(model_config):
            tokenizer = cached_tokenizer_from_config(model_config)
            processor = mm_registry.create_processor(model_config,
                                                     tokenizer,
                                                     disable_cache=True)
            profiler = MultiModalProfiler(processor)
            dummy_data_factory = (profiler.get_encoder_dummy_data
                                  if is_encoder_data else
                                  profiler.get_decoder_dummy_data)
            dummy_data = dummy_data_factory(seq_len)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
                dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
            else:
                dummy_factory = self._get_dummy_data_factory(model_cls)
            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                dummy_factory,
                overrides=model_config.mm_processor_kwargs,
                requires_kw_only=False,
                allow_var_kwargs=True,
            )

            dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                       _MultiModalCounts(mm_counts),
                                       **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = len(dummy_data.seq_data.prompt_token_ids)
        if num_tokens < seq_len:
            if is_encoder_data:
                logger.warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {num_tokens} tokens instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {num_tokens} tokens instead.")

        mm_data = dummy_data.multi_modal_data
        if mm_data is not None and not isinstance(mm_data, MultiModalKwargs):
            if mm_counts is None:
                mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

            for modality, items in mm_data.items():
                num_items = len(items) if isinstance(items, list) else 1
                num_expected = mm_counts[modality]
                # Raise explicitly (rather than ``assert``) so the check is
                # not stripped under ``python -O``, matching the token check.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy "
                        f"'{modality}' instances for profiling, but found "
                        f"{num_items} instances instead.")

        return dummy_data

    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
        **kwargs: object,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before
        :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
        """

        def wrapper(model_cls: N) -> N:
            if self._input_processors_by_model_type.contains(model_cls,
                                                             strict=True):
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def _get_model_input_processor(self, model_cls: type[nn.Module]):
        """Return the input processor for ``model_cls`` (default: no-op)."""
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def _ensure_mm_kwargs(
        self,
        inputs: SingletonInputs,
        mm_processor_kwargs: dict[str, Any],
    ) -> None:
        """
        Ensure ``inputs`` carries its multi-modal kwargs.

        Token inputs get ``mm_processor_kwargs`` filled in if the model's
        input processor did not set it; multimodal inputs must already
        contain ``mm_kwargs``.
        """
        if inputs["type"] == "token":
            # In case the input processor for that model fails to set it
            if "mm_processor_kwargs" not in inputs:
                inputs["mm_processor_kwargs"] = mm_processor_kwargs
        elif inputs["type"] == "multimodal":
            # Be more strict in V2
            assert "mm_kwargs" in inputs
        else:
            assert_never(inputs["type"])  # type: ignore[arg-type]

    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs", {}),  # type: ignore
            processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        processed_inputs = processor(
            InputContext(model_config),
            inputs,
            **mm_processor_kwargs,
        )

        if is_encoder_decoder_inputs(processed_inputs):
            self._ensure_mm_kwargs(processed_inputs["encoder"],
                                   mm_processor_kwargs)
            self._ensure_mm_kwargs(processed_inputs["decoder"],
                                   mm_processor_kwargs)
        else:
            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)

        return processed_inputs

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model, by binding ``model_config`` as its first argument.
        """
        return functools.partial(self.process_input, model_config)