"vllm/vscode:/vscode.git/clone" did not exist on "2836dd73f13015ee386c544760ca0d16888203f3"
registry.py 6.27 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
from collections.abc import Mapping
3
from dataclasses import dataclass
4
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
5

6
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
7
from typing_extensions import TypeVar
8

9
from vllm.transformers_utils.processor import cached_processor_from_config
10
from vllm.transformers_utils.tokenizer import AnyTokenizer
11
from vllm.utils import resolve_mm_processor_kwargs
12
13

if TYPE_CHECKING:
14
    from vllm.config import ModelConfig
15
16
    from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict,
                                 MultiModalRegistry)
17
18
    from vllm.sequence import SequenceData

19
20
21
# Generic type for ``InputContext.init_processor``: the returned instance has
# the same type as the class passed in as ``typ``.
_T = TypeVar("_T")
# HuggingFace config type returned by ``get_hf_config``; bounded by and
# defaulting to ``PretrainedConfig`` (the ``default=`` parameter is the PEP 696
# TypeVar-default feature, hence ``TypeVar`` comes from typing_extensions).
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
# HuggingFace processor type returned by ``get_hf_processor``; bounded by and
# defaulting to ``ProcessorMixin``.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
22
23


24
25
26
27
28
29
30
31
32
33
@dataclass(frozen=True)
class InputContext:
    """
    Immutable holder of model information that input-processing code can
    consult when preparing or adjusting the inputs for a model.
    """

    # The configuration of the model; every accessor below reads from it.
    model_config: "ModelConfig"

    def get_hf_config(
        self,
        typ: Union[type[_C], tuple[type[_C], ...]] = PretrainedConfig,
        /,
    ) -> _C:
        """
        Return the model's HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) after verifying that it is
        an instance of ``typ``.

        Args:
            typ: The expected config class, or a tuple of acceptable classes.

        Raises:
            TypeError: If the configuration is not an instance of ``typ``.
        """
        config = self.model_config.hf_config
        if isinstance(config, typ):
            return config

        raise TypeError("Invalid type of HuggingFace config. "
                        f"Expected type: {typ}, but "
                        f"found type: {type(config)}")

    def get_hf_image_processor_config(self) -> dict[str, Any]:
        """
        Return the HuggingFace image processor configuration stored on the
        model config.
        """
        image_processor_cfg = self.model_config.hf_image_processor_config
        return image_processor_cfg

    def get_mm_config(self):
        """
        Return the multimodal config of the model.

        Raises:
            RuntimeError: If the model has no multimodal config, i.e. it is
                not a multimodal model.
        """
        multimodal_config = self.model_config.multimodal_config
        if multimodal_config is not None:
            return multimodal_config

        raise RuntimeError("Not a multimodal model")

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Return the model's HuggingFace processor
        (:class:`transformers.ProcessorMixin`), checked against ``typ``.

        Loading is delegated to ``cached_processor_from_config``; any extra
        keyword arguments are forwarded to it unchanged.

        Raises:
            TypeError: If the processor is not of the specified type.
        """
        # NOTE(review): the type check is presumably performed inside
        # cached_processor_from_config (via ``processor_cls``); this method
        # does not check the type itself.
        return cached_processor_from_config(
            self.model_config,
            processor_cls=typ,
            **kwargs,
        )

    def init_processor(
        self,
        typ: type[_T],
        /,
        **kwargs: object,
    ) -> _T:
        """
        Instantiate ``typ`` (a HuggingFace-like processor class) with the
        model's configured ``mm_processor_kwargs`` overlaid by ``kwargs``.

        Keys passed explicitly in ``kwargs`` win over the values taken from
        the model configuration.
        """
        configured = self.model_config.mm_processor_kwargs
        init_kwargs = dict({} if configured is None else configured)
        init_kwargs.update(kwargs)
        return typ(**init_kwargs)
111

112

113
114
115
116
117
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
    """
    :class:`InputContext` extended with the tokenizer, so that HuggingFace
    processors obtained through it are built with that tokenizer.
    """

    # The tokenizer used to tokenize the inputs.
    tokenizer: AnyTokenizer

    def get_hf_processor(
        self,
        typ: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
        /,
        **kwargs: object,
    ) -> _P:
        """
        Same as :meth:`InputContext.get_hf_processor`, but always forwards
        ``tokenizer=self.tokenizer`` to the processor loader.
        """
        return super().get_hf_processor(
            typ,
            tokenizer=self.tokenizer,
            **kwargs,
        )

    def call_hf_processor(
        self,
        hf_processor: ProcessorMixin,
        data: Mapping[str, object],
        kwargs: Optional[Mapping[str, object]] = None,
    ) -> BatchFeature:
        """
        Call :code:`hf_processor` on the prompt :code:`data`
        (text, image, audio...) with configurable options :code:`kwargs`.

        The model's ``mm_processor_kwargs`` and ``kwargs`` are combined by
        ``resolve_mm_processor_kwargs`` (checked against ``hf_processor``),
        and the call always requests ``return_tensors="pt"``.

        Args:
            hf_processor: A callable HuggingFace processor.
            data: Prompt inputs, passed to the processor as keyword args.
            kwargs: Extra processor options; ``None`` (the default) means
                no overrides.

        Returns:
            Whatever the processor returns (annotated as ``BatchFeature``).

        Raises:
            RuntimeError: If the processor call raises; the original
                exception is chained as ``__cause__``.
        """
        assert callable(hf_processor)

        # A None sentinel instead of a ``{}`` default avoids sharing one
        # mutable dict object across every call of this method.
        if kwargs is None:
            kwargs = {}

        base_kwargs = self.model_config.mm_processor_kwargs
        if base_kwargs is None:
            base_kwargs = {}

        merged_kwargs = resolve_mm_processor_kwargs(
            base_kwargs,
            kwargs,
            hf_processor,
            requires_kw_only=False,
            allow_var_kwargs=True,
        )

        try:
            return hf_processor(**data, **merged_kwargs, return_tensors="pt")
        except Exception as exc:
            msg = (f"Failed to apply {type(hf_processor).__name__} "
                   f"on data={data} with kwargs={merged_kwargs}")

            raise RuntimeError(msg) from exc

162

163
class DummyData(NamedTuple):
    """
    Placeholder inputs handed to a model while profiling its memory usage.

    Note: This is only used in V0.
    """

    # Token sequence of the dummy prompt.
    seq_data: "SequenceData"
    # Multi-modal inputs that accompany the prompt, if any.
    multi_modal_data: Optional["MultiModalDataDict"] = None
    # Multi-modal placeholder information for the prompt, if any.
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None


175
176
class InputRegistry:
    """
    Produces dummy inputs used when profiling a model's memory usage.

    Note: This is only used in V0.
    """

    def dummy_data_for_profiling(
        self,
        model_config: "ModelConfig",
        seq_len: int,
        mm_registry: "MultiModalRegistry",
        is_encoder_data: bool = False,
    ) -> DummyData:
        """
        Build dummy data of length ``seq_len`` for memory profiling of the
        model described by ``model_config``.

        Text-only models get a plain token sequence. For multimodal models
        the data comes from ``mm_registry``: encoder data carries token ids
        only, while decoder data also carries the multi-modal inputs and
        placeholders.
        """
        # Imported here rather than at module level to avoid a circular import.
        from vllm.sequence import SequenceData

        if not model_config.is_multimodal_model:
            # NOTE(review): (0, seq_len) presumably means token id 0 repeated
            # seq_len times — confirm against SequenceData.
            text_only_seq = SequenceData.from_prompt_token_counts((0, seq_len))
            return DummyData(seq_data=text_only_seq)

        if is_encoder_data:
            # The encoder side gets token ids only, no multi-modal payload.
            encoder_dummy = mm_registry.get_encoder_dummy_data(
                model_config, seq_len)
            return DummyData(
                seq_data=SequenceData.from_seqs(
                    encoder_dummy.prompt_token_ids))

        decoder_dummy = mm_registry.get_decoder_dummy_data(
            model_config, seq_len)
        decoder_seq = SequenceData.from_seqs(decoder_dummy.prompt_token_ids)
        return DummyData(
            seq_data=decoder_seq,
            multi_modal_data=decoder_dummy.multi_modal_data,
            multi_modal_placeholders=decoder_dummy.multi_modal_placeholders,
        )