registry.py 15 KB
Newer Older
1
import functools
2
from collections import UserDict
3
from dataclasses import dataclass
4
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
5
                    Optional, Protocol, Type)
6
7

from torch import nn
8
9
from transformers import PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar, assert_never
10
11

from vllm.logger import init_logger
12
13
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer
14
15
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
                        print_warning_once, resolve_mm_processor_kwargs)
16

17
18
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
19
20

if TYPE_CHECKING:
21
    from vllm.config import ModelConfig
22
23
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
24
25
26
27
    from vllm.sequence import SequenceData

# Module-level logger, named after this module; used by ``InputRegistry``
# to warn when an existing registration is overwritten.
logger = init_logger(__name__)

28
# Type variable for HuggingFace config classes. It is bound to (and defaults
# to) ``PretrainedConfig`` so that ``InputContext.get_hf_config`` can return
# the same config type that the caller asked it to check for.
C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig)
29
30
31
32
33
34
35
36
37
38
39
40


@dataclass(frozen=True)
class InputContext:
    """
    Immutable bundle of model information that input-handling hooks
    can consult when they need to adapt the inputs to the model.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
        """
        Return the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model after
        verifying that it is an instance of ``hf_config_type``.

        Args:
            hf_config_type: The config class that the caller expects.

        Returns:
            The ``hf_config`` object stored on the model config.

        Raises:
            TypeError: If the config is not an instance of
                ``hf_config_type``.
        """
        config = self.model_config.hf_config
        if isinstance(config, hf_config_type):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {hf_config_type}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> Dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on
        the model config.
        """
        image_processor_config = self.model_config.hf_image_processor_config
        return image_processor_config

    def get_mm_config(self):
        """
        Return the multimodal config of the model.

        Raises:
            RuntimeError: If the model config has no multimodal config,
                i.e. the model is not a multimodal model.
        """
        multimodal_config = self.model_config.multimodal_config
        if multimodal_config is not None:
            return multimodal_config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Fetch the HuggingFace processor for the model via
        ``cached_get_processor``.

        The keyword arguments are the model config's
        ``mm_processor_kwargs`` (treated as empty when ``None``)
        overlaid with ``kwargs``; on a key clash ``kwargs`` wins.
        """
        model_config = self.model_config

        init_kwargs = model_config.mm_processor_kwargs
        processor_kwargs = {
            **({} if init_kwargs is None else init_kwargs),
            **kwargs,
        }

        return cached_get_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **processor_kwargs,
        )

90

91
92
93
94
95
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An :class:`InputContext` that additionally carries the tokenizer,
    which is passed on to the HuggingFace processor.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
        """
        Fetch the HuggingFace processor for the model via
        ``cached_get_processor``, passing ``self.tokenizer`` as its
        tokenizer.

        The remaining keyword arguments are the model config's
        ``mm_processor_kwargs`` (treated as empty when ``None``)
        overlaid with ``kwargs``; on a key clash ``kwargs`` wins.
        """
        model_config = self.model_config

        init_kwargs = model_config.mm_processor_kwargs
        processor_kwargs = {
            **({} if init_kwargs is None else init_kwargs),
            **kwargs,
        }

        return cached_get_processor(
            model_config.model,
            tokenizer=self.tokenizer,  # Override the tokenizer with ours
            trust_remote_code=model_config.trust_remote_code,
            **processor_kwargs,
        )

    def resolve_hf_processor_call_kwargs(
        self,
        hf_processor: ProcessorMixin,
        inference_kwargs: Mapping[str, object],
    ) -> Mapping[str, object]:
        """
        Work out the keyword arguments for calling ``hf_processor``.

        Delegates to ``resolve_mm_processor_kwargs`` with the model
        config's ``mm_processor_kwargs`` (treated as empty when
        ``None``), the per-request ``inference_kwargs`` and the
        processor itself.

        Args:
            hf_processor: The processor that will be called; it must be
                callable.
            inference_kwargs: Keyword arguments supplied at inference
                time.

        Returns:
            The mapping produced by ``resolve_mm_processor_kwargs``.
        """
        assert callable(hf_processor)

        init_kwargs = self.model_config.mm_processor_kwargs

        return resolve_mm_processor_kwargs(
            {} if init_kwargs is None else init_kwargs,
            inference_kwargs,
            hf_processor,
        )
126
127


128
129
130
# Type variable for model classes (subclasses of ``nn.Module``). The
# ``InputRegistry.register_*`` decorators use it to declare that they
# return the very class they were given.
N = TypeVar("N", bound=Type[nn.Module])


131
132
133
134
135
136
137
138
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    This is the value produced by a :class:`DummyDataFactory` and returned
    from :meth:`InputRegistry.dummy_data_for_profiling`.
    """

    # Sequence data for the dummy prompt; the profiling code reads
    # ``seq_data.prompt_token_ids`` to check how many tokens were produced.
    seq_data: "SequenceData"
    # Optional dummy multi-modal inputs (``None`` when there are none).
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Optional multi-modal placeholder information.
    # NOTE(review): presumably describes where the multi-modal items sit in
    # the prompt -- confirm against ``MultiModalPlaceholderDict``.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


139
140
141
142
143
144
145
class DummyDataFactory(Protocol):
    """Call signature expected of a dummy data factory."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build the dummy data that is fed to the model.

        Args:
            ctx: Information about the model being profiled.
            seq_len: The sequence length that the dummy data is
                requested for.
            mm_counts: A count for each multi-modal key.
            **mm_processor_kwargs: Overrides provided at initialization
                time to values in the config whose values may affect
                the number of tokens per instance.

        Returns:
            The dummy data.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        ...


161
class _MultiModalCounts(UserDict[str, int]):
162
163
164
165
166
167
168
169
170
171
172
173
174
    """
    Wraps `mm_counts` for a more informative error message
    when attempting to access a plugin that does not exist.
    """

    def __getitem__(self, key: str) -> int:
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {set(self.keys())}")
            raise KeyError(msg) from exc

175

176
# Type alias for input processors: a callable that takes the model's
# ``InputContext`` and the inputs, and returns the processed inputs.
# ``InputRegistry.process_input`` additionally passes the resolved
# multi-modal processor kwargs to the processor as keyword arguments.
InputProcessor = Callable[[InputContext, ProcessorInputs], ProcessorInputs]
177
178
179
180
181
182
183
184
185
186
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.

    Three per-model-class registries are kept: dummy (decoder) data
    factories, dummy encoder data factories, and input processors.
    """

    def __init__(self) -> None:
        self._dummy_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._dummy_encoder_factories_by_model_type = \
            ClassRegistry[nn.Module, DummyDataFactory]()
        self._input_processors_by_model_type = \
            ClassRegistry[nn.Module, InputProcessor]()

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> DummyData:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        It is used for model classes that have no registered factory, and
        ignores ``ctx`` and ``mm_counts``.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        return DummyData(SequenceData.from_prompt_token_counts((0, seq_len)))

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.

        Returns:
            A class decorator that registers ``factory`` for the decorated
            model class and returns that class unchanged. A warning is
            logged if the class already had a factory registered.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_factories_by_model_type.contains(model_cls,
                                                            strict=True):
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_data_factory(self, model_cls: Type[nn.Module]):
        """
        Look up the dummy data factory for ``model_cls``, falling back to
        :meth:`_default_dummy_data_factory`.
        """
        return self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def register_dummy_encoder_data(self, factory: DummyDataFactory):
        """
        Register a dummy encoder data factory to a model class

        This is similar to :meth:`~register_dummy_data`, but for encoder input.

        Returns:
            A class decorator that registers ``factory`` for the decorated
            model class and returns that class unchanged. A warning is
            logged if the class already had a factory registered.
        """

        def wrapper(model_cls: N) -> N:
            if self._dummy_encoder_factories_by_model_type.contains(
                    model_cls, strict=True):
                logger.warning(
                    "Model class %s already has dummy encoder data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_encoder_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
        """
        Look up the dummy encoder data factory for ``model_cls``, falling
        back to :meth:`_default_dummy_data_factory`.
        """
        return self._dummy_encoder_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        If ``mm_registry`` has a processor for the model, the dummy data is
        obtained from that processor. Otherwise it comes from the dummy
        (encoder) data factory registered for the model class.

        Args:
            model_config: The configuration identifying the model.
            seq_len: The minimum number of prompt tokens expected in the
                dummy data.
            mm_registry: The multi-modal registry that supplies the
                per-prompt limits and, if available, the processor.
            is_encoder_data: Whether to create encoder data. This selects
                the encoder factory and turns a token shortfall into a
                warning instead of an error.

        Returns:
            The dummy data.

        Raises:
            AssertionError: If fewer than ``seq_len`` tokens were produced
                (non-encoder data only), or if fewer instances of a
                modality were produced than the per-prompt limit.

        See also:
            :ref:`enabling_multimodal_inputs`

        Note:
            This should be called after
            :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture
        from vllm.multimodal import MultiModalKwargs
        from vllm.multimodal.utils import cached_get_tokenizer

        if mm_registry.has_processor(model_config):
            tokenizer = cached_get_tokenizer(
                model_config.tokenizer,
                trust_remote_code=model_config.trust_remote_code,
            )
            processor = mm_registry.create_processor(model_config, tokenizer)

            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
            mm_max_tokens = mm_registry.get_max_tokens_by_modality(
                model_config)

            dummy_data = processor.get_dummy_data(seq_len, mm_counts,
                                                  mm_max_tokens)
        else:
            model_cls, _ = get_model_architecture(model_config)
            if is_encoder_data:
                dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
            else:
                dummy_factory = self._get_dummy_data_factory(model_cls)
            mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
            # Only forward the overrides that the factory can accept
            mm_processor_kwargs = get_allowed_kwarg_only_overrides(
                dummy_factory, overrides=model_config.mm_processor_kwargs)

            dummy_data = dummy_factory(InputContext(model_config), seq_len,
                                       _MultiModalCounts(mm_counts),
                                       **mm_processor_kwargs)

        # Having more tokens is over-conservative but otherwise fine
        num_tokens = len(dummy_data.seq_data.prompt_token_ids)
        if num_tokens < seq_len:
            if is_encoder_data:
                print_warning_once(
                    f"Expected at least {seq_len} dummy encoder tokens for "
                    f"profiling, but found {num_tokens} tokens instead.")
            else:
                raise AssertionError(
                    f"Expected at least {seq_len} dummy tokens for profiling, "
                    f"but found {num_tokens} tokens instead.")

        if (dummy_data.multi_modal_data is not None and
                not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)):
            for k, v in dummy_data.multi_modal_data.items():
                # A non-list value counts as a single instance
                num_items = len(v) if isinstance(v, list) else 1
                num_expected = mm_counts[k]
                # Raise explicitly (rather than use an `assert` statement)
                # so that this check is kept under `python -O`, like the
                # token count check above.
                if num_items < num_expected:
                    raise AssertionError(
                        f"Expected at least {num_expected} dummy '{k}' "
                        f"instances for profiling, but found {num_items} "
                        "instances instead.")

        return dummy_data

    def _default_input_processor(
        self,
        ctx: InputContext,
        inputs: ProcessorInputs,
    ) -> ProcessorInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.

        Returns:
            A class decorator that registers ``processor`` for the
            decorated model class and returns that class unchanged. A
            warning is logged if the class already had an input processor
            registered.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if self._input_processors_by_model_type.contains(model_cls,
                                                             strict=True):
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def _get_model_input_processor(self, model_cls: Type[nn.Module]):
        """
        Look up the input processor for ``model_cls``, falling back to
        :meth:`_default_input_processor`.
        """
        return self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

    def _ensure_mm_kwargs(
        self,
        inputs: SingletonInputs,
        mm_processor_kwargs: Dict[str, Any],
    ) -> None:
        """
        Make sure that processed inputs carry their multi-modal kwargs.

        Token inputs that lack ``"mm_processor_kwargs"`` are updated in
        place with ``mm_processor_kwargs``; multimodal inputs are required
        to already contain ``"mm_kwargs"``.
        """
        if inputs["type"] == "token":
            # In case the input processor for that model fails to set it
            if "mm_processor_kwargs" not in inputs:
                inputs["mm_processor_kwargs"] = mm_processor_kwargs
        elif inputs["type"] == "multimodal":
            # Be more strict in V2
            assert "mm_kwargs" in inputs
        else:
            assert_never(inputs["type"])

    def process_input(self, model_config: "ModelConfig",
                      inputs: ProcessorInputs) -> ProcessorInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.

        Args:
            model_config: The configuration identifying the model.
            inputs: The inputs to process.

        Returns:
            The processed inputs. For encoder-decoder inputs, both the
            encoder and the decoder parts are passed through
            :meth:`_ensure_mm_kwargs`.

        See also:
            :ref:`input_processing_pipeline`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        processor = self._get_model_input_processor(model_cls)

        # Handle multimodal processor kwargs with priority:
        #     Inference kwargs -> Init kwargs -> {}
        # If it's empty, it'll fall back to the default kwarg values
        mm_processor_kwargs = resolve_mm_processor_kwargs(
            model_config.mm_processor_kwargs,
            inputs.get("mm_processor_kwargs", {}),  # type: ignore
            processor,
        )

        processed_inputs = processor(
            InputContext(model_config),
            inputs,
            **mm_processor_kwargs,
        )

        if is_encoder_decoder_inputs(processed_inputs):
            self._ensure_mm_kwargs(processed_inputs["encoder"],
                                   mm_processor_kwargs)
            self._ensure_mm_kwargs(processed_inputs["decoder"],
                                   mm_processor_kwargs)
        else:
            self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)

        return processed_inputs

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.

        Returns:
            A :func:`functools.partial` of :meth:`process_input` with
            ``model_config`` already bound, so that callers only need to
            pass the inputs.
        """
        return functools.partial(self.process_input, model_config)