registry.py 16.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from collections import UserDict
5
from dataclasses import dataclass
6
7
from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
                    Optional, Protocol, Union)
8
9

from torch import nn
10
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
11
from typing_extensions import TypeVar, assert_never
12
13

from vllm.logger import init_logger
14
15
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
16
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
17
                        resolve_mm_processor_kwargs)
18

19
20
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
21
22

if TYPE_CHECKING:
23
    from vllm.config import ModelConfig
24
25
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
26
27
28
29
    from vllm.sequence import SequenceData

# Module-level logger, named after this module.
logger = init_logger(__name__)

# Type variables used by `InputContext.get_hf_config` / `get_hf_processor`
# so the return type follows the class passed as `typ`. Both fall back to
# the HuggingFace base class when no narrower type is requested.
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
P = TypeVar("P", bound=ProcessorMixin, default=ProcessorMixin)
32
33


34
35
36
37
38
39
40
41
42
43
44
class HashableDict(dict):
    """
    A :class:`dict` subclass whose instances can be used as
    ``lru_cache`` keys.
    """

    # NOTE: the builtin dict is unhashable; overriding ``__hash__`` directly
    # on a subclass is the simplest way to obtain a hashable mapping.
    def __hash__(self) -> int:  # type: ignore[override]
        # Order-insensitive snapshot of the (key, value) pairs.
        # Every value must itself be hashable for this to succeed.
        item_snapshot = frozenset(self.items())
        return hash(item_snapshot)


45
46
47
48
49
50
51
52
53
54
@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[C], tuple[type[C], ...]] = PretrainedConfig,
        /,
    ) -> C:
        """
        Get the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected configuration class, or a tuple of acceptable
                classes (positional-only).

        Returns:
            ``model_config.hf_config``, unchanged.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(hf_config)}")

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    def get_hf_processor(
        self,
        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> P:
        """
        Get the HuggingFace processor
        (:class:`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Args:
            typ: The expected processor class, or a tuple of acceptable
                classes (positional-only). When a single class is given it
                is also forwarded as ``processor_cls``.
            **kwargs: Extra keyword arguments for the processor. They take
                precedence over ``model_config.mm_processor_kwargs``.

        Returns:
            The processor returned by ``cached_get_processor``.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        base_kwargs = self.model_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        # Call-site kwargs override the ones configured on the model.
        merged_kwargs = {**base_kwargs, **kwargs}

        # A tuple of types cannot be used as a single `processor_cls`,
        # so it is only forwarded when exactly one class was requested.
        if isinstance(typ, type):
            merged_kwargs["processor_cls"] = typ

        # NOTE: Pythonic dict is not hashable and will raise unhashable type
        # error when calling `cached_get_processor`, therefore we need to
        # wrap it to a hashable dict. A new dict is built rather than
        # rewriting `merged_kwargs` while iterating over it.
        merged_kwargs = {
            key: HashableDict(value) if isinstance(value, dict) else value
            for key, value in merged_kwargs.items()
        }

        hf_processor = cached_get_processor(
            self.model_config.model,
            trust_remote_code=self.model_config.trust_remote_code,
            **merged_kwargs,
        )
        if not isinstance(hf_processor, typ):
            raise TypeError("Invalid type of HuggingFace processor. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(hf_processor)}")

        return hf_processor
136

137

138
139
140
141
142
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An :class:`InputContext` that also carries the tokenizer, so that
    HuggingFace processors can be fetched and applied to prompt data.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[P], tuple[type[P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> P:
        """
        Get the HuggingFace processor of the model, passing
        :attr:`tokenizer` along to the parent implementation.

        Args:
            typ: The expected processor class, or a tuple of acceptable
                classes (positional-only).
            **kwargs: Extra keyword arguments forwarded to
                :meth:`InputContext.get_hf_processor`.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> BatchFeature:
        """
        Call :code:`hf_processor` on the prompt :code:`data`
        (text, image, audio...) with configurable options :code:`kwargs`.

        Args:
            hf_processor: The processor to invoke; it must be callable.
            data: Keyword arguments holding the prompt data.
            kwargs: Per-call options, resolved against
                ``model_config.mm_processor_kwargs`` by
                ``resolve_mm_processor_kwargs``.

        Returns:
            The processor output, requested with ``return_tensors="pt"``.

        Raises:
            RuntimeError: If the processor call fails for any reason. The
                original exception is chained as the cause.
        """
        assert callable(hf_processor)

        base_kwargs = self.model_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
        except Exception as exc:
            # Re-raise with the processor name and its arguments so the
            # failing call can be identified from the message.
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise RuntimeError(msg) from exc

187

188
# Type variable for a model class (a subclass of `nn.Module`). It lets the
# `InputRegistry.register_*` class decorators return the exact class they
# were given.
N = TypeVar("N", bound=type[nn.Module])
189
190


191
192
193
194
195
196
197
198
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    This is the value produced by a :class:`DummyDataFactory` and returned
    by :meth:`InputRegistry.dummy_data_for_profiling`.
    """

    # Dummy sequence data (required); the profiler checks the length of its
    # `prompt_token_ids` against the requested sequence length.
    seq_data: "SequenceData"
    # Optional dummy multi-modal inputs; `None` when there are none.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Optional multi-modal placeholders; `None` when not provided.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


199
200
201
202
203
204
205
class DummyDataFactory(Protocol):
    """
    Structural type of the callables accepted by
    :meth:`InputRegistry.register_dummy_data` and
    :meth:`InputRegistry.register_dummy_encoder_data`.
    """

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Create dummy data to be inputted into the model.

        Args:
            ctx: Context describing the model being profiled.
            seq_len: The sequence length requested for the dummy data.
            mm_counts: Number of items to generate, keyed by modality.
            **mm_processor_kwargs: See the note below.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.

            The :code:`mm_processor_kwargs` are overrides provided at
            initialization time to values in the config whose values
            may affect the number of tokens per instance.
        """
        ...


221
class _MultiModalCounts(UserDict[str, int]):
222
223
224
225
226
227
228
229
230
231
232
233
234
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc

235

236
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
"""
Preprocess the inputs to the model.

Takes the :class:`InputContext` of the model and the inputs to process,
and returns the processed inputs.
"""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        """Create a registry with no model-specific entries."""
        self._dummy_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._dummy_encoder_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._input_processors_by_model_type = \
            ClassRegistry[nn.Module, InputProcessor]()

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.

        Returns:
            A class decorator that records ``factory`` for the decorated
            model class and returns that class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_factories_by_model_type.contains(model_cls,
                                                            strict=True):
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
        """
        Look up the dummy data factory for ``model_cls``, falling back to
        :meth:`_default_dummy_data_factory`.
        """
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Register a dummy encoder data factory to a model class

        This is similar to :meth:`~register_dummy_data`, but for encoder input.

        Returns:
            A class decorator that records ``factory`` for the decorated
            model class and returns that class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_encoder_factories_by_model_type.contains(
                    model_cls, strict=True):
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_encoder_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
        """
        Look up the dummy encoder data factory for ``model_cls``, falling
        back to :meth:`_default_dummy_data_factory`.
        """
        return self._dummy_encoder_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Args:
            model_config: Configuration identifying the model.
            seq_len: Minimum number of prompt tokens the dummy data
                should contain.
            mm_registry: Registry used to create the multi-modal processor
                or to look up the per-prompt multi-modal limits.
            is_encoder_data: Use the registered encoder factory instead of
                the decoder one. A token shortfall is then only logged
                rather than raised.

        Raises:
            AssertionError: If fewer than ``seq_len`` tokens are produced
                (non-encoder data only), or if fewer multi-modal items
                than the per-prompt limit are produced for a modality.

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.multimodal import MultiModalKwargs
        from vllm.multimodal.profiling import MultiModalProfiler
        from vllm.multimodal.utils import cached_get_tokenizer

        # Only the legacy branch needs the limits up front. They are
        # fetched lazily if the validation at the end needs them, so the
        # name is always bound.
        mm_counts: Optional[Mapping[str, int]] = None

        if mm_registry.has_processor(model_config):
            tokenizer = cached_get_tokenizer(
                model_config.tokenizer,
                trust_remote_code=model_config.trust_remote_code,
            )
            processor = mm_registry.create_processor(model_config, tokenizer)
            profiler = MultiModalProfiler(processor)
            dummy_data = profiler.get_dummy_data(seq_len)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
                dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
            else:
                dummy_factory = self._get_dummy_data_factory(model_cls)
            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                dummy_factory, overrides=model_config.mm_processor_kwargs)

            dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                       _MultiModalCounts(mm_counts),
                                       **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        prompt_token_ids = dummy_data.seq_data.prompt_token_ids
        if len(prompt_token_ids) < seq_len:
            if is_encoder_data:
                logger.warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {len(prompt_token_ids)} tokens "
                    "instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {len(prompt_token_ids)} tokens instead.")

        mm_data = dummy_data.multi_modal_data
        if mm_data is not None and not isinstance(mm_data, MultiModalKwargs):
            if mm_counts is None:
                mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)

            for k, v in mm_data.items():
                num_items = len(v) if isinstance(v, list) else 1
                num_expected = mm_counts[k]
                # Raised explicitly (not `assert`) so the check also runs
                # under `python -O`, like the token-count check above.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy '{k}' "
                        f"instances for profiling, but found {num_items} "
                        "instances instead.")

        return dummy_data

    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before
        :meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.

        Returns:
            A class decorator that records ``processor`` for the decorated
            model class and returns that class unchanged.
        """

        def wrapper(model_cls: N) -> N:
            if self._input_processors_by_model_type.contains(model_cls,
                                                             strict=True):
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def _get_model_input_processor(self, model_cls: type[nn.Module]):
        """
        Look up the input processor for ``model_cls``, falling back to
        :meth:`_default_input_processor`.
        """
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def _ensure_mm_kwargs(
        self,
        inputs: SingletonInputs,
        mm_processor_kwargs: dict[str, Any],
    ):
        """
        Make sure processed ``inputs`` carry their multi-modal kwargs.

        Token inputs get ``mm_processor_kwargs`` filled in (in place) when
        missing; multimodal inputs must already contain ``mm_kwargs``.
        """
        if inputs["type"] == "token":
            # In case the input processor for that model fails to set it
            if "mm_processor_kwargs" not in inputs:
                inputs["mm_processor_kwargs"] = mm_processor_kwargs
        elif inputs["type"] == "multimodal":
            # Be more strict in V2
            assert "mm_kwargs" in inputs
        else:
            assert_never(inputs["type"])  # type: ignore[arg-type]

    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.

        Returns:
            The processed inputs, with the multi-modal kwargs ensured on
            each singleton input (encoder and decoder for encoder-decoder
            inputs).
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs", {}),  # type: ignore
            processor,
        )

        processed_inputs = processor(
            InputContext(model_config),
            inputs,
            **mm_processor_kwargs,
        )

        if is_encoder_decoder_inputs(processed_inputs):
            self._ensure_mm_kwargs(processed_inputs["encoder"],
                                   mm_processor_kwargs)
            self._ensure_mm_kwargs(processed_inputs["decoder"],
                                   mm_processor_kwargs)
        else:
            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)

        return processed_inputs

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.

        Returns:
            A partial of :meth:`process_input` with ``model_config``
            already bound.
        """
        return functools.partial(self.process_input, model_config)