# processor.py
# SPDX-License-Identifier: Apache-2.0

3
from functools import lru_cache
4
from typing import TYPE_CHECKING, Any, Union, cast
5

6
from transformers.processing_utils import ProcessorMixin
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from typing_extensions import TypeVar

if TYPE_CHECKING:
    from vllm.config import ModelConfig

# Processor type returned by `get_processor` and `cached_processor_from_config`;
# it is bound to, and defaults to, HuggingFace's `ProcessorMixin`.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)


def _freeze_for_hash(value: Any) -> Any:
    """Return a hashable stand-in for *value*, freezing nested containers.

    dicts become frozensets of ``(key, frozen_value)`` pairs, lists and tuples
    become tuples, and sets become frozensets. Any other value is returned
    unchanged, so an unhashable leaf object still raises ``TypeError`` when
    the result is hashed.
    """
    if isinstance(value, dict):
        return frozenset(
            (key, _freeze_for_hash(item)) for key, item in value.items())
    if isinstance(value, (list, tuple)):
        return tuple(_freeze_for_hash(item) for item in value)
    if isinstance(value, (set, frozenset)):
        return frozenset(value)
    return value


class HashableDict(dict):
    """
    A dictionary that can be hashed by lru_cache.

    The hash is computed from a frozen copy of the contents, so values may
    themselves be dicts, lists, tuples or sets (nested to any depth).
    """

    # NOTE: pythonic dict is not hashable,
    # we override on it directly for simplicity
    def __hash__(self) -> int:  # type: ignore[override]
        return hash(_freeze_for_hash(self))


class HashableList(list):
    """
    A list that can be hashed by lru_cache.

    Like ``HashableDict``, the hash is computed from a frozen copy of the
    contents, so nested containers are supported.
    """

    # NOTE: pythonic list is not hashable,
    # we override on it directly for simplicity
    def __hash__(self) -> int:  # type: ignore[override]
        return hash(_freeze_for_hash(self))


def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
    """Merge ``model_config.mm_processor_kwargs`` with call-site ``kwargs``.

    Call-site ``kwargs`` override keys from the config. Top-level dict and
    list values are wrapped in ``HashableDict`` / ``HashableList`` so the
    result can be passed as keyword arguments to the ``lru_cache``-wrapped
    loaders (e.g. ``cached_get_processor``) without raising
    ``TypeError: unhashable type``.
    """
    base_kwargs = model_config.mm_processor_kwargs
    if base_kwargs is None:
        base_kwargs = {}

    merged_kwargs = {**base_kwargs, **kwargs}

    # NOTE: Pythonic dicts and lists are not hashable and will raise an
    # unhashable type error when calling `cached_get_processor`, therefore
    # we wrap them into hashable subclasses. Reassigning existing keys while
    # iterating is safe because the dict's size does not change.
    for key, value in merged_kwargs.items():
        if isinstance(value, dict):
            merged_kwargs[key] = HashableDict(value)
        elif isinstance(value, list):
            merged_kwargs[key] = HashableList(value)

    return merged_kwargs


def get_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load a processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name or path passed to ``from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; when False, a
            ``ValueError`` during loading is re-raised as a ``RuntimeError``
            suggesting the ``--trust-remote-code`` flag.
        processor_cls: Expected processor class, or a tuple of classes. When
            it is ``ProcessorMixin`` itself (the default) or a tuple,
            ``AutoProcessor`` resolves the class; otherwise the given class's
            own ``from_pretrained`` is called.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded processor.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code``
            is False (the original error is chained as the cause).
        ValueError: Loading raised ``ValueError`` and ``trust_remote_code``
            is True.
        TypeError: The loaded object is not an instance of ``processor_cls``.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    # A single concrete class is loaded directly; the generic base class or a
    # tuple of candidates leaves class resolution to AutoProcessor.
    use_auto_processor = (processor_cls is ProcessorMixin
                          or isinstance(processor_cls, tuple))
    processor_factory = (AutoProcessor
                         if use_auto_processor else processor_cls)

    try:
        processor = processor_factory.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        # Bare `raise` re-raises the active exception with its original
        # traceback, without adding a re-raise frame.
        raise

    # `isinstance` accepts both a single class and a tuple of classes.
    if not isinstance(processor, processor_cls):
        raise TypeError("Invalid type of HuggingFace processor. "
                        f"Expected type: {processor_cls}, but "
                        f"found type: {type(processor)}")

    return processor


# Memoized variant of `get_processor`. Passing a function straight to
# `lru_cache` uses maxsize=128; that bound is spelled out explicitly here.
cached_get_processor = lru_cache(maxsize=128)(get_processor)


def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Return the memoized HuggingFace processor described by ``model_config``.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``model_config.mm_processor_kwargs`` are merged with
    ``kwargs`` (call-site values win) and everything is forwarded to
    ``cached_get_processor``.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **processor_kwargs,
    )


def get_image_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name or path passed to
            ``AutoImageProcessor.from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to ``from_pretrained``; when False, a
            ``ValueError`` during loading is re-raised as a ``RuntimeError``
            suggesting the ``--trust-remote-code`` flag.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded image processor, typed as ``BaseImageProcessor``.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code``
            is False (the original error is chained as the cause).
        ValueError: Loading raised ``ValueError`` and ``trust_remote_code``
            is True.
    """
    # don't put this import at the top level
    # it will call torch.cuda.device_count()
    from transformers import AutoImageProcessor
    from transformers.image_processing_utils import BaseImageProcessor

    try:
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        # Bare `raise` re-raises the active exception with its original
        # traceback, without adding a re-raise frame.
        raise

    # `cast` only informs type checkers; no runtime type check is performed.
    return cast(BaseImageProcessor, processor)


# Memoized variant of `get_image_processor`. Passing a function straight to
# `lru_cache` uses maxsize=128; that bound is spelled out explicitly here.
cached_get_image_processor = lru_cache(maxsize=128)(get_image_processor)


def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the memoized image processor described by ``model_config``.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``model_config.mm_processor_kwargs`` are merged with
    ``kwargs`` (call-site values win) and forwarded to
    ``cached_get_image_processor``.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_image_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        **processor_kwargs,
    )


def get_video_processor(
    processor_name: str,
    *args: Any,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load a video processor for the given model name via HuggingFace.

    The full processor is loaded through ``get_processor`` with the same
    positional and keyword arguments, and its ``video_processor`` attribute
    is returned (typed as ``BaseImageProcessor``).
    """
    # Deferred import: importing transformers at module level would call
    # torch.cuda.device_count().
    from transformers.image_processing_utils import BaseImageProcessor

    hf_processor = get_processor(
        processor_name,
        *args,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    video_processor = hf_processor.video_processor

    # `cast` only informs type checkers; nothing is checked at runtime.
    return cast(BaseImageProcessor, video_processor)


# Memoized variant of `get_video_processor`. Passing a function straight to
# `lru_cache` uses maxsize=128; that bound is spelled out explicitly here.
cached_get_video_processor = lru_cache(maxsize=128)(get_video_processor)


def cached_video_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the memoized video processor described by ``model_config``.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``model_config.mm_processor_kwargs`` are merged with
    ``kwargs`` (call-site values win) and forwarded to
    ``cached_get_video_processor``.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)

    return cached_get_video_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        **processor_kwargs,
    )