# registry.py 7.54 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
6

7
import torch
8
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
9
from typing_extensions import TypeVar
10

11
12
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
13
from vllm.transformers_utils.processor import cached_processor_from_config
14
from vllm.utils import resolve_mm_processor_kwargs
15
16

if TYPE_CHECKING:
    # Real types, only visible to static type checkers.
    from vllm.config import ModelConfig
    from vllm.multimodal import (
        MultiModalDataDict,
        MultiModalPlaceholderDict,
        MultiModalRegistry,
    )
    from vllm.sequence import SequenceData
    from vllm.transformers_utils.tokenizer import AnyTokenizer
else:
    # At runtime these modules are not imported here; every name is aliased
    # to ``Any`` so the annotations in this module still resolve.
    ModelConfig = MultiModalDataDict = MultiModalPlaceholderDict = Any
    MultiModalRegistry = SequenceData = AnyTokenizer = Any

logger = init_logger(__name__)

# Generic type parameters for the context helpers. ``_C`` and ``_P`` default
# to the base HuggingFace config / processor classes when left unspecified.
_T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)


@dataclass(frozen=True)
class InputContext:
    """
    Immutable bundle of model information that input processing can consult
    in order to adapt the inputs to the model.
    """

    # The configuration of the model whose inputs are being handled.
    model_config: ModelConfig

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (`transformers.PretrainedConfig`) after verifying that it is an
        instance of `typ`.

        Raises:
            TypeError: If the configuration is not an instance of `typ`.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """Return the model's HuggingFace image processor configuration."""
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Return the model's multimodal configuration.

        Raises:
            RuntimeError: If the model has no multimodal configuration,
                i.e. it is not a multimodal model.
        """
        if (mm_config := self.model_config.multimodal_config) is not None:
            return mm_config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the model's HuggingFace processor
        (`transformers.ProcessorMixin`), obtained through
        `cached_processor_from_config` with `processor_cls=typ`, so that its
        type is checked as well.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        return cached_processor_from_config(self.model_config,
                                            processor_cls=typ,
                                            **kwargs)

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate `typ`, a HuggingFace-like processor class, using the
        model's configured `mm_processor_kwargs` overridden by `kwargs`.
        """
        mm_config = self.model_config.get_multimodal_config()
        defaults = mm_config.mm_processor_kwargs
        # Explicit call-site kwargs win over the configured defaults.
        init_kwargs = {
            **({} if defaults is None else defaults),
            **kwargs,
        }
        return typ(**init_kwargs)


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    `InputContext` extended with the tokenizer used for the inputs, plus
    helpers to obtain and call the model's HuggingFace processor.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as `InputContext.get_hf_processor`, but always forwards this
        context's tokenizer as `tokenizer=self.tokenizer`.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Optional[Mapping[str, object]] = None,
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        `kwargs` are combined with the model's configured
        `mm_processor_kwargs` via `resolve_mm_processor_kwargs`, and
        floating-point tensors in the output are cast to the model dtype.

        Args:
            hf_processor: The callable processor to apply.
            data: Prompt data, passed to the processor as keyword arguments.
            kwargs: Per-call processor options. `None` (the default) means
                no per-call overrides.

        Returns:
            A `BatchFeature` if the processor returned one; otherwise the
            processor's raw output with floating-point tensors cast.

        Raises:
            ValueError: If calling the processor or post-processing its
                output fails (the original exception is chained).
        """
        assert callable(hf_processor)

        # `None` sentinel instead of a shared mutable `{}` default.
        if kwargs is None:
            kwargs = {}

        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def maybe_cast_dtype(x):
            # This mimics the behavior of transformers.BatchFeature
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=self.model_config.dtype)
            return x

        try:
            output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
            # this emulates output.to(dtype=self.model_config.dtype)
            if isinstance(output, BatchFeature):
                cast_output = json_map_leaves(maybe_cast_dtype, output.data)
                return BatchFeature(cast_output)

            cast_output = json_map_leaves(maybe_cast_dtype, output)

            logger.warning_once(
                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                "Make sure to match the behaviour of `ProcessorMixin` when "
                "implementing custom processors.")
            return cast_output

        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise ValueError(msg) from exc


class DummyData(NamedTuple):
    """
    Placeholder inputs constructed solely for memory profiling.

    Note: This is only used in V0.
    """

    # Token sequence fed to the model while profiling.
    seq_data: SequenceData
    # Optional multi-modal inputs; left as None for text-only dummy data.
    multi_modal_data: Optional[MultiModalDataDict] = None
    # Optional placeholder information accompanying `multi_modal_data`.
    multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None


class InputRegistry:
    """
    Provides dummy model inputs for memory profiling.

    Note: This is only used in V0.
    """

    def dummy_data_for_profiling(
        self,
        model_config: ModelConfig,
        seq_len: int,
        mm_registry: MultiModalRegistry,
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Build dummy inputs of length `seq_len` for profiling the memory
        usage of the model described by ``model_config``.

        - Text-only models: a sequence built by
          `SequenceData.from_prompt_token_counts((0, seq_len))`
          (presumably `seq_len` copies of token id 0).
        - Multi-modal models, encoder side: token ids from
          `mm_registry.get_encoder_dummy_data`, without multi-modal data.
        - Multi-modal models, decoder side: token ids, multi-modal data and
          placeholders from `mm_registry.get_decoder_dummy_data`.
        """
        # Imported here rather than at module level to avoid a circular import
        from vllm.sequence import SequenceData

        if model_config.is_multimodal_model:
            if is_encoder_data:
                # Encoder dummy data carries token ids only.
                encoder_dummy = mm_registry.get_encoder_dummy_data(
                    model_config, seq_len)
                return DummyData(seq_data=SequenceData.from_seqs(
                    encoder_dummy.prompt_token_ids))

            decoder_dummy = mm_registry.get_decoder_dummy_data(
                model_config, seq_len)
            return DummyData(
                seq_data=SequenceData.from_seqs(
                    decoder_dummy.prompt_token_ids),
                multi_modal_data=decoder_dummy.multi_modal_data,
                multi_modal_placeholders=decoder_dummy.
                multi_modal_placeholders,
            )

        text_only_seq = SequenceData.from_prompt_token_counts((0, seq_len))
        return DummyData(seq_data=text_only_seq)