processor.py 7.36 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from functools import lru_cache
4
from typing import TYPE_CHECKING, Any, Optional, Union, cast
5

6
from transformers.processing_utils import ProcessorMixin
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from typing_extensions import TypeVar

if TYPE_CHECKING:
    from vllm.config import ModelConfig

# Processor type produced by the loaders in this module. It is bounded by,
# and (via the PEP 696 ``default``) falls back to, HuggingFace's
# ProcessorMixin.
_P = TypeVar(
    "_P",
    bound=ProcessorMixin,
    default=ProcessorMixin,
)


def _make_hashable(value: Any) -> Any:
    """Return a hashable stand-in for ``value``, used by the ``__hash__``
    overrides below.

    Dicts become frozensets of ``(key, frozen_value)`` pairs, lists/tuples
    become tuples and sets become frozensets, applied recursively. Any other
    value is returned unchanged and must itself be hashable. For containers
    whose elements are already hashable, the result hashes exactly like the
    previous ``frozenset(d.items())`` / ``tuple(l)`` forms.
    """
    if isinstance(value, dict):
        return frozenset(
            (key, _make_hashable(item)) for key, item in value.items())
    if isinstance(value, (list, tuple)):
        return tuple(_make_hashable(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(_make_hashable(item) for item in value)
    return value


class HashableDict(dict):
    """
    A dictionary that can be hashed by lru_cache.

    The hash is computed from a frozen snapshot of the current contents, so
    nested dicts/lists/sets are supported. Do not mutate an instance after
    using it as a cache key: its hash would change.
    """

    # NOTE: pythonic dict is not hashable,
    # we override on it directly for simplicity
    def __hash__(self) -> int:  # type: ignore[override]
        return hash(_make_hashable(self))


class HashableList(list):
    """
    A list that can be hashed by lru_cache.

    Nested dicts/lists/sets are frozen recursively before hashing. Do not
    mutate an instance after using it as a cache key.
    """

    def __hash__(self) -> int:  # type: ignore[override]
        return hash(_make_hashable(self))


35
def _merge_mm_kwargs(
    model_config: "ModelConfig",
    **kwargs: Any,
) -> dict[str, Any]:
    """Merge the model's multimodal processor kwargs with call-site kwargs.

    Keyword arguments passed here take precedence over
    ``mm_config.mm_processor_kwargs`` (treated as empty when it is None).
    Top-level ``dict`` and ``list`` values are wrapped in ``HashableDict`` /
    ``HashableList`` so the result can be forwarded to the
    ``lru_cache``-backed ``cached_get_*`` helpers.

    Returns:
        A new dict; neither ``mm_processor_kwargs`` nor ``kwargs`` is mutated.
    """
    mm_config = model_config.get_multimodal_config()
    base_kwargs = mm_config.mm_processor_kwargs
    if base_kwargs is None:
        base_kwargs = {}

    merged_kwargs = {**base_kwargs, **kwargs}

    # NOTE: Pythonic dict is not hashable and will raise unhashable type
    # error when calling `cached_get_processor`, therefore we need to
    # wrap it to a hashable dict.
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for key, value in merged_kwargs.items():
        if isinstance(value, dict):
            merged_kwargs[key] = HashableDict(value)
        elif isinstance(value, list):
            merged_kwargs[key] = HashableList(value)

    return merged_kwargs
52

53
54
55

def get_processor(
    processor_name: str,
    *args: Any,
    revision: Optional[str] = None,
    trust_remote_code: bool = False,
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load a processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name or path passed to ``from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        revision: Model revision forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; also decides
            whether a loading ``ValueError`` is re-raised as-is or wrapped
            with a hint about ``--trust-remote-code``.
        processor_cls: Expected processor class, or tuple of classes. When
            it is ``ProcessorMixin`` itself or a tuple, ``AutoProcessor`` is
            used to load; otherwise ``processor_cls.from_pretrained`` is
            called directly.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded processor.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code``
            is False.
        ValueError: Loading raised ``ValueError`` and ``trust_remote_code``
            is True (re-raised unchanged).
        TypeError: The loaded processor is not an instance of
            ``processor_cls``.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    # A tuple of candidate classes has no single loader, and the generic
    # mixin is not a concrete processor, so let AutoProcessor pick the class.
    use_auto = (processor_cls is ProcessorMixin
                or isinstance(processor_cls, tuple))
    processor_factory = AutoProcessor if use_auto else processor_cls

    try:
        processor = processor_factory.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise

    if not isinstance(processor, processor_cls):
        raise TypeError("Invalid type of HuggingFace processor. "
                        f"Expected type: {processor_cls}, but "
                        f"found type: {type(processor)}")

    return processor
99
100


101
102
103
# Memoized `get_processor` (LRU, 128 entries). Every argument, including
# keyword-argument values, must be hashable to be used as a cache key.
cached_get_processor = lru_cache(maxsize=128)(get_processor)


104
105
106
107
108
109
110
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load, with caching, the HuggingFace processor for ``model_config``.

    The model name, revision and ``trust_remote_code`` flag are taken from
    ``model_config``. ``kwargs`` are merged over the config's multimodal
    processor kwargs by ``_merge_mm_kwargs``, which also makes them hashable
    for the ``cached_get_processor`` lookup.
    """
    return cached_get_processor(
        model_config.model,
        revision=model_config.revision,
        trust_remote_code=model_config.trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **_merge_mm_kwargs(model_config, **kwargs),
    )


118
119
120
def get_feature_extractor(
    processor_name: str,
    *args: Any,
    revision: Optional[str] = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an audio feature extractor for the given model name
    via HuggingFace.

    Args:
        processor_name: Model name or path passed to
            ``AutoFeatureExtractor.from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        revision: Model revision forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; also decides
            whether a loading ``ValueError`` is re-raised as-is or wrapped
            with a hint about ``--trust-remote-code``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded feature extractor, typed as ``FeatureExtractionMixin``.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code``
            is False.
        ValueError: Loading raised ``ValueError`` and ``trust_remote_code``
            is True (re-raised unchanged).
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoFeatureExtractor
    from transformers.feature_extraction_utils import FeatureExtractionMixin

    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoFeatureExtractor does not separate such
        # errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the feature extractor. If the feature "
                "extractor is a custom extractor not yet available in the "
                "HuggingFace transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise

    # `cast` only informs type checkers; no runtime type check happens here.
    return cast(FeatureExtractionMixin, feature_extractor)


# Memoized `get_feature_extractor` (LRU, 128 entries); all arguments must be
# hashable to form the cache key.
cached_get_feature_extractor = lru_cache(maxsize=128)(get_feature_extractor)


def cached_feature_extractor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load, with caching, the feature extractor for ``model_config``.

    The model name, revision and ``trust_remote_code`` flag are taken from
    ``model_config``. ``kwargs`` are merged over the config's multimodal
    processor kwargs by ``_merge_mm_kwargs`` before the
    ``cached_get_feature_extractor`` lookup.
    """
    return cached_get_feature_extractor(
        model_config.model,
        revision=model_config.revision,
        trust_remote_code=model_config.trust_remote_code,
        **_merge_mm_kwargs(model_config, **kwargs),
    )


170
171
172
def get_image_processor(
    processor_name: str,
    *args: Any,
    revision: Optional[str] = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name or path passed to
            ``AutoImageProcessor.from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        revision: Model revision forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; also decides
            whether a loading ``ValueError`` is re-raised as-is or wrapped
            with a hint about ``--trust-remote-code``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded image processor, typed as ``BaseImageProcessor``.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code``
            is False.
        ValueError: Loading raised ``ValueError`` and ``trust_remote_code``
            is True (re-raised unchanged).
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise

    # `cast` only informs type checkers; no runtime type check happens here.
    return cast(BaseImageProcessor, processor)


208
209
210
211
212
213
214
215
216
# Memoized `get_image_processor` (LRU, 128 entries); all arguments must be
# hashable to form the cache key.
cached_get_image_processor = lru_cache(maxsize=128)(get_image_processor)


def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load, with caching, the image processor for ``model_config``.

    The model name, revision and ``trust_remote_code`` flag are taken from
    ``model_config``. ``kwargs`` are merged over the config's multimodal
    processor kwargs by ``_merge_mm_kwargs`` before the
    ``cached_get_image_processor`` lookup.
    """
    return cached_get_image_processor(
        model_config.model,
        revision=model_config.revision,
        trust_remote_code=model_config.trust_remote_code,
        **_merge_mm_kwargs(model_config, **kwargs),
    )