registry.py 7.31 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Mapping
4
from dataclasses import dataclass
5
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
6

7
import torch
8
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
9
from typing_extensions import TypeVar
10

11
12
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
13
from vllm.transformers_utils.processor import cached_processor_from_config
14
from vllm.transformers_utils.tokenizer import AnyTokenizer
15
from vllm.utils import resolve_mm_processor_kwargs
16
17

if TYPE_CHECKING:
18
    from vllm.config import ModelConfig
19
20
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
21
22
    from vllm.sequence import SequenceData

23
24
25
# Unconstrained type variable: `InputContext.init_processor` returns an
# instance of whatever class it is handed.
_T = TypeVar("_T")
# HuggingFace config type checked and returned by `get_hf_config`.
# `default=` requires the `typing_extensions` flavour of `TypeVar`.
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
# HuggingFace processor type requested from `get_hf_processor`.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
26

27
28
# Module-level logger; in this file it is used to warn (once) when a
# processor does not return a `BatchFeature`.
logger = init_logger(__name__)

29

30
31
32
33
34
35
36
37
38
39
@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Get the HuggingFace configuration
        (`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Args:
            typ: The expected config class, or a tuple of acceptable
                classes. Positional-only.

        Returns:
            `model_config.hf_config`, unchanged.

        Raises:
            TypeError: If the configuration is not of the specified type.
        """
        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, typ):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {typ}, but "
                            f"found type: {type(hf_config)}")

        return hf_config

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Get the HuggingFace image processor configuration of the model.

        Returns:
            `model_config.hf_image_processor_config`, unchanged.
        """
        return self.model_config.hf_image_processor_config

    def get_mm_config(self):
        """
        Get the multimodal config of the model.

        Returns:
            `model_config.multimodal_config` when it is not `None`.

        Raises:
            RuntimeError: If the model is not a multimodal model.
        """
        mm_config = self.model_config.multimodal_config
        if mm_config is None:
            raise RuntimeError("Not a multimodal model")

        return mm_config

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Get the HuggingFace processor
        (`transformers.ProcessorMixin`) of the model,
        additionally checking its type.

        Args:
            typ: The expected processor class, or a tuple of acceptable
                classes. Positional-only; forwarded as `processor_cls`.
            **kwargs: Extra keyword arguments forwarded verbatim to
                `cached_processor_from_config`.

        Raises:
            TypeError: If the processor is not of the specified type.
                No check is done in this method; presumably it is
                enforced inside `cached_processor_from_config`.
        """
        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Initialize a HuggingFace-like processor class, merging the
        keyword arguments with those in the model's configuration.

        Args:
            typ: The class to instantiate. Positional-only.
            **kwargs: Keyword arguments for the constructor. On a key
                collision these take precedence over the model's
                `mm_processor_kwargs`.

        Returns:
            A new `typ` instance built from the merged keyword arguments.
        """
        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        # Later entries win, so call-site kwargs override the config's.
        merged_kwargs = {**base_kwargs, **kwargs}

        return typ(**merged_kwargs)
118

119

120
121
122
123
124
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    An `InputContext` that also carries the tokenizer, so that HuggingFace
    processors can be obtained and invoked with it.
    """

    tokenizer: AnyTokenizer
    """The tokenizer used to tokenize the inputs."""

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Get the HuggingFace processor of the model, same as the parent
        implementation but always passing `tokenizer=self.tokenizer`.

        Args:
            typ: The expected processor class, or a tuple of acceptable
                classes. Positional-only.
            **kwargs: Extra keyword arguments forwarded to the parent.
                Passing `tokenizer` here collides with the one supplied
                by this method.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Mapping[str, object] = {},
    ) -> Union[BatchFeature, JSONTree]:
        """
        Call `hf_processor` on the prompt `data`
        (text, image, audio...) with configurable options `kwargs`.

        Args:
            hf_processor: The processor to call; must be callable.
            data: Keyword arguments holding the prompt data, unpacked
                into the processor call.
            kwargs: Per-call processor options, resolved together with
                the model's `mm_processor_kwargs` by
                `resolve_mm_processor_kwargs`.

        Returns:
            The processor output with every floating-point tensor leaf
            cast to `model_config.dtype`. It is re-wrapped in a
            `BatchFeature` when the processor returned one; otherwise the
            cast structure is returned as is and a warning is logged once.

        Raises:
            ValueError: If anything inside the processing step fails; the
                original exception is chained as the cause.
        """
        assert callable(hf_processor)

        mm_config = self.model_config.get_multimodal_config()
        base_kwargs = mm_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        def maybe_cast_dtype(x):
            # This mimics the behavior of transformers.BatchFeature:
            # only floating-point tensors are cast, everything else is
            # passed through untouched.
            if isinstance(x, torch.Tensor) and x.is_floating_point():
                return x.to(dtype=self.model_config.dtype)
            return x

        try:
            output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
            # this emulates output.to(dtype=self.model_config.dtype)
            cast_output = json_map_leaves(maybe_cast_dtype, output)
            if isinstance(output, BatchFeature):
                return BatchFeature(cast_output)

            logger.warning_once(
                f"{type(hf_processor).__name__} did not return `BatchFeature`. "
                "Make sure to match the behaviour of `ProcessorMixin` when "
                "implementing custom processors.")
            return cast_output

        except Exception as exc:
            # Re-raise with the processor name and inputs for context,
            # keeping the original exception as the cause.
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise ValueError(msg) from exc
186

187

188
class DummyData(NamedTuple):
    """
    Dummy data used for profiling.

    Note: This is only used in V0.
    """

    # The dummy token sequence (required field).
    seq_data: "SequenceData"
    # Dummy multi-modal inputs accompanying the sequence; `None` when
    # there are none.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Placeholder information for the multi-modal inputs; `None` when
    # there are none.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


200
201
class InputRegistry:
    """
    Builds the dummy inputs used to profile a model's memory usage.

    Note: This is only used in V0.
    """

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        Args:
            model_config: Configuration of the model being profiled.
            seq_len: Sequence length requested for the dummy data.
            mm_registry: Registry that supplies encoder/decoder dummy
                data for multi-modal models.
            is_encoder_data: When true (and the model is multi-modal),
                return encoder dummy data instead of decoder dummy data.

        Returns:
            A `DummyData`. Only the decoder path of a multi-modal model
            fills in `multi_modal_data` and `multi_modal_placeholders`.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        if not model_config.is_multimodal_model:
            # Text-only model: a token-count based sequence is enough.
            # (Presumably `seq_len` copies of token id 0.)
            seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
            return DummyData(seq_data=seq_data)

        # Encoder dummy data does not contain multi-modal data
        if is_encoder_data:
            enc_data = mm_registry.get_encoder_dummy_data(
                model_config, seq_len)
            seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
            return DummyData(seq_data=seq_data)

        dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)

        return DummyData(
            seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
            multi_modal_data=dec_data.multi_modal_data,
            multi_modal_placeholders=dec_data.multi_modal_placeholders,
        )