registry.py 7.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
6

7
import torch
8
from packaging.version import Version
9
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
10
from transformers import __version__ as TRANSFORMERS_VERSION
11
from typing_extensions import TypeVar
12

13
14
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
15
from vllm.transformers_utils.processor import cached_processor_from_config
16
from vllm.utils import resolve_mm_processor_kwargs
17
18

if TYPE_CHECKING:
19
    from vllm.config import ModelConfig
20
21
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
22
    from vllm.sequence import SequenceData
23
24
25
26
27
28
29
30
    from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
    ModelConfig = Any
    MultiModalDataDict = Any
    MultiModalPlaceholderDict = Any
    MultiModalRegistry = Any
    SequenceData = Any
    AnyTokenizer = Any
31

32
33
34
# NOTE: `TypeVar` here is the `typing_extensions` one (see the imports),
# which accepts the `default=` argument used below.

# Type of the processor class instantiated by `InputContext.init_processor`.
_T = TypeVar("_T")
# HuggingFace config type requested from `InputContext.get_hf_config`;
# bound to, and defaulting to, `PretrainedConfig`.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
# HuggingFace processor type requested from `get_hf_processor`;
# bound to, and defaulting to, `ProcessorMixin`.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
35

36
37
# Module-level logger; `call_hf_processor` uses it to emit a one-time warning.
logger = init_logger(__name__)

38

39
40
41
42
43
44
45
@dataclass(frozen=True)
class InputContext:
    """
    Immutable holder of model information that may be consulted when
    modifying the inputs.
    """

    model_config: ModelConfig
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (`transformers.PretrainedConfig`) after verifying that it is an
        instance of `typ`.

        Args:
            typ: The expected config class, or a tuple of acceptable
                classes. Positional-only.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on
        the model config.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the multimodal config stored on the model config.

        Raises:
            RuntimeError: If the model is not a multimodal model, i.e.
                the model config has no multimodal config.
        """
        config = self.model_config.multimodal_config
        if config is not None:
            return config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the model's HuggingFace processor
        (`transformers.ProcessorMixin`), expected to be of type `typ`.

        The lookup is delegated to `cached_processor_from_config`, which
        receives `typ` as `processor_cls` together with `kwargs`.

        Raises:
            TypeError: If the processor is not of the specified type
                (the check is not performed in this method itself).
        """
        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate the HuggingFace-like processor class `typ`.

        The constructor receives the `mm_processor_kwargs` from the
        model's multimodal config, overridden key by key with `kwargs`.
        """
        configured = (
            self.model_config.get_multimodal_config().mm_processor_kwargs)
        defaults = {} if configured is None else configured

        # Later entries win, so explicit `kwargs` override the config.
        init_kwargs = {**defaults, **kwargs}
        return typ(**init_kwargs)
127

128

129
130
131
132
133
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An `InputContext` that additionally carries the tokenizer and can
    invoke HuggingFace processors on prompt data.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as `InputContext.get_hf_processor`, but also passes
        `self.tokenizer` to the processor as the `tokenizer` keyword
        (overriding any caller-supplied value), except on
        Transformers 4.53.0.
        """
        # Transformers 4.53.0 has issue with passing tokenizer to
        # initialize processor. We disable it for this version.
        # See: https://github.com/vllm-project/vllm/issues/20224
        if Version(TRANSFORMERS_VERSION) != Version("4.53.0"):
            kwargs["tokenizer"] = self.tokenizer

        return super().get_hf_processor(
            typ,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Optional[Mapping[str, object]] = None,
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        The call-time `kwargs` are combined with the `mm_processor_kwargs`
        of the model's multimodal config by `resolve_mm_processor_kwargs`,
        and the processor is always asked for PyTorch tensors
        (`return_tensors="pt"`). Floating-point tensors in the output are
        cast to `self.model_config.dtype`.

        Args:
            hf_processor: The processor to call; must be callable.
            data: Keyword arguments holding the prompt data.
            kwargs: Optional per-call processor options. `None` (the
                default) is treated as an empty mapping.

        Returns:
            A new `BatchFeature` if the processor returned one; otherwise
            the processor's output with its leaves cast, after logging a
            one-time warning.

        Raises:
            ValueError: If calling the processor or post-processing its
                output fails; the original exception is chained.
        """
        assert callable(hf_processor)

        # A `None` sentinel avoids sharing one mutable `{}` default
        # between calls.
        if kwargs is None:
            kwargs = {}

        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def maybe_cast_dtype(x):
            # This mimics the behavior of transformers.BatchFeature
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=self.model_config.dtype)
            return x

        # NOTE: The whole call-and-cast sequence stays inside the `try`
        # so that every failure is reported as `ValueError`.
        try:
            output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
            # this emulates output.to(dtype=self.model_config.dtype)
            if isinstance(output, BatchFeature):
                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                return BatchFeature(cast_output)

            cast_output = json_map_leaves(maybe_cast_dtype, output)

            logger.warning_once(
                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                "Make sure to match the behaviour of `ProcessorMixin` when "
                "implementing custom processors.")
            return cast_output

        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise ValueError(msg) from exc
201

202

203
class DummyData(NamedTuple):
    """
    Placeholder inputs used when profiling a model.

    Only `seq_data` is required; the two multi-modal fields default to
    `None`.

    Note: This is only used in V0.
    """

    # Token sequence of the dummy prompt.
    seq_data: SequenceData
    # Multi-modal inputs accompanying the dummy prompt, if any.
    multi_modal_data: Optional[MultiModalDataDict] = None
    # Placeholder information for the multi-modal inputs, if any.
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
213
214


215
216
class InputRegistry:
    """
    Builds the dummy inputs used for profiling.

    Note: This is only used in V0.
    """

    def dummy_data_for_profiling(
        self,
        model_config: ModelConfig,
        seq_len: int,
        mm_registry: MultiModalRegistry,
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Build the dummy data used to profile the memory usage of the
        model described by ``model_config``.

        Three cases are handled:

        - Not a multi-modal model: a sequence built from the prompt
          token counts ``(0, seq_len)``.
        - Multi-modal model with ``is_encoder_data`` set: the encoder
          dummy prompt from ``mm_registry``, without multi-modal data.
        - Otherwise: the decoder dummy prompt from ``mm_registry``,
          with its multi-modal data and placeholders.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        if not model_config.is_multimodal_model:
            text_only = SequenceData.from_prompt_token_counts((0, seq_len))
            return DummyData(seq_data=text_only)

        # Encoder dummy data does not contain multi-modal data
        if is_encoder_data:
            encoder_dummy = mm_registry.get_encoder_dummy_data(
                model_config, seq_len)
            return DummyData(
                seq_data=SequenceData.from_seqs(
                    encoder_dummy.prompt_token_ids))

        decoder_dummy = mm_registry.get_decoder_dummy_data(
            model_config, seq_len)
        decoder_seq = SequenceData.from_seqs(decoder_dummy.prompt_token_ids)

        return DummyData(
            seq_data=decoder_seq,
            multi_modal_data=decoder_dummy.multi_modal_data,
            multi_modal_placeholders=decoder_dummy.multi_modal_placeholders,
        )