processor.py 17.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import importlib
import inspect
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints

from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoProcessor,
    AutoVideoProcessor,
    processing_utils,
)
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar

from vllm.logger import init_logger
from vllm.transformers_utils.gguf_utils import is_gguf
from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
26

27
28
logger = init_logger(__name__)

if TYPE_CHECKING:
    # Only needed for type annotations (referenced as the string "ModelConfig"),
    # so `vllm.config` is not imported at runtime.
    from vllm.config import ModelConfig
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

def _transformers_v4_compatibility_import():
    """Restore the Transformers v4 `ChatTemplateLoadKwargs` name when it is missing.

    Some remote-code processors still import `ChatTemplateLoadKwargs`, which in
    Transformers v4 was a subset of `ProcessorChatTemplateKwargs`. Transformers v5
    merged the two and removed `ChatTemplateLoadKwargs`. When the old name is
    absent but the new one exists, the old name is bound to the new class on
    `processing_utils` so those imports keep working. If the old name already
    exists, or the new one is missing, nothing is changed.

    This can be removed if `HCXVisionForCausalLM` is upstreamed to Transformers."""
    if getattr(processing_utils, "ChatTemplateLoadKwargs", None) is not None:
        # Old name still present (e.g. Transformers v4): leave it untouched.
        return
    merged_kwargs_cls = getattr(processing_utils, "ProcessorChatTemplateKwargs", None)
    if merged_kwargs_cls is None:
        # Neither name exists, so there is nothing to alias.
        return
    processing_utils.ChatTemplateLoadKwargs = merged_kwargs_cls


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def _transformers_v4_compatibility_init() -> None:
    """Some remote code processors may define `optional_attributes` in their
    `ProcessorMixin` subclass, and then pass these arbitrary attributes directly to
    `ProcessorMixin.__init__`, which is no longer allowed in Transformers v5. For
    backward compatibility, we intercept these optional attributes and set them on the
    processor instance before calling the original `ProcessorMixin.__init__`.

    The patch is applied at most once (the wrapper is tagged with `_vllm_patched`)
    and is skipped when `ProcessorMixin` still defines `optional_attributes`
    (Transformers v4) or looks like a mock (docs builds).

    This can be removed if `Molmo2ForConditionalGeneration` is upstreamed to
    Transformers."""
    # Transformers v4
    if hasattr(ProcessorMixin, "optional_attributes"):
        return
    # Transformers v5: already patched by an earlier call.
    if hasattr(ProcessorMixin.__init__, "_vllm_patched"):
        return

    original_init = ProcessorMixin.__init__

    # `wraps` keeps the original name/docstring and sets `__wrapped__`, so
    # `inspect.signature(ProcessorMixin.__init__)` keeps reporting the real
    # parameters instead of the wrapper's `(*args, **kwargs)`.
    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        # Move optional attributes out of kwargs onto the instance so the
        # original __init__ never sees them.
        for optional_attribute in getattr(self, "optional_attributes", []):
            if optional_attribute in kwargs:
                setattr(self, optional_attribute, kwargs.pop(optional_attribute))

        original_init(self, *args, **kwargs)

    # Only patch if ProcessorMixin is not mocked (for docs builds)
    if not hasattr(ProcessorMixin, "_mock_name"):
        __init__._vllm_patched = True  # type: ignore[attr-defined]
        ProcessorMixin.__init__ = __init__


78
# Install the Transformers v4 -> v5 compatibility shims at import time, before any
# processor is loaded through this module.
_transformers_v4_compatibility_import()
_transformers_v4_compatibility_init()

# Type variables for the loaders below. `default=` (PEP 696, via typing_extensions)
# names the type used when the variable is not otherwise determined.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
_V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor)
83
84
85
86
87
88
89
90
91
92
93
94
95


class HashableDict(dict):
    """
    A ``dict`` subclass usable as (part of) an ``lru_cache`` key.

    The hash is built from the current ``(key, value)`` pairs, independent of
    insertion order. Every value must therefore be hashable itself, and mutating
    the dict changes its hash.
    """

    # NOTE: the builtin dict is unhashable; overriding __hash__ on the subclass
    # is the simplest way to make it usable as a cache key.
    def __hash__(self) -> int:  # type: ignore[override]
        pairs = frozenset((key, self[key]) for key in self)
        return hash(pairs)


96
97
98
99
100
101
102
103
104
class HashableList(list):
    """
    A ``list`` subclass usable as (part of) an ``lru_cache`` key.

    Hashes exactly like a tuple of its current elements, so every element must
    be hashable and mutating the list changes its hash.
    """

    def __hash__(self) -> int:  # type: ignore[override]
        snapshot = (*self,)
        return hash(snapshot)


105
def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
    """Return the callable whose signature user kwargs are checked against.

    Used by ``_merge_mm_kwargs`` to filter kwargs:

    - a tuple of classes, or ``ProcessorMixin`` itself, maps to
      ``AutoProcessor.from_pretrained``;
    - any other class exposing ``from_pretrained`` maps to that method;
    - anything else is returned unchanged (its constructor is the factory).
    """
    if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
        return AutoProcessor.from_pretrained
    if hasattr(processor_cls, "from_pretrained"):
        return processor_cls.from_pretrained

    return processor_cls

113

114
115
def _to_hashable(value: Any) -> Any:
    """Recursively wrap dicts and lists so ``value`` can be part of a cache key.

    Dicts become ``HashableDict`` and lists become ``HashableList``, including
    containers nested inside them. Any other value (tuples, sets, scalars, ...) is
    returned unchanged and must already be hashable.
    """
    if isinstance(value, dict):
        return HashableDict({k: _to_hashable(v) for k, v in value.items()})
    if isinstance(value, list):
        return HashableList(_to_hashable(v) for v in value)
    return value


def _merge_mm_kwargs(
    model_config: "ModelConfig",
    processor_cls: type | tuple[type, ...],
    /,
    **kwargs,
):
    """Merge per-call kwargs with the model's multimodal processor kwargs.

    The merged kwargs are filtered down to those accepted by the factory chosen by
    ``_get_processor_factory_fn(processor_cls)``. Every dict/list value, at any
    nesting depth, is wrapped so the result can be forwarded to the
    ``lru_cache``-wrapped loaders in this module.

    Returns:
        A new ``dict`` of allowed kwargs with hashable values.
    """
    mm_config = model_config.get_multimodal_config()
    merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)

    factory = _get_processor_factory_fn(processor_cls)
    allowed_kwargs = get_allowed_kwarg_only_overrides(
        factory,
        merged_kwargs,
        requires_kw_only=False,
        allow_var_kwargs=True,
    )

    # NOTE: Plain dicts/lists are not hashable and raise "unhashable type" when
    # passed to `cached_get_processor` (or the other cached loaders). Wrap them
    # recursively: hashing a wrapper also hashes everything nested inside it.
    return {key: _to_hashable(value) for key, value in allowed_kwargs.items()}
140

141
142
143

def get_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load a processor for the given model name via HuggingFace.

    Args:
        processor_name: Model repo id or local path; passed through
            ``convert_model_repo_to_path`` first.
        *args: Extra positional arguments forwarded to the loader.
        revision: Revision to load; ``None`` is replaced by ``"main"``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        processor_cls: Expected processor type.
            - ``ProcessorMixin`` (the default) or a tuple of classes: loaded via
              ``AutoProcessor.from_pretrained``.
            - Any other ``ProcessorMixin`` subclass: loaded via its own
              ``from_pretrained``.
            - Any other class: instantiated as ``processor_cls(*args, **kwargs)``,
              without ``processor_name``.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        The loaded processor.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code`` is
            False; the message suggests enabling it.
        ValueError: Loading raised ``ValueError`` with ``trust_remote_code`` set.
        TypeError: The loaded object is not an instance of ``processor_cls``.
    """
    if revision is None:
        revision = "main"
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
            processor = AutoProcessor.from_pretrained(
                processor_name,
                *args,
                revision=revision,
                trust_remote_code=trust_remote_code,
                **kwargs,
            )
        elif issubclass(processor_cls, ProcessorMixin):
            processor = processor_cls.from_pretrained(
                processor_name,
                *args,
                revision=revision,
                trust_remote_code=trust_remote_code,
                **kwargs,
            )
        else:
            # Processors that are standalone classes unrelated to HF
            processor = processor_cls(*args, **kwargs)
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the processor. If the processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        raise

    if not isinstance(processor, processor_cls):
        raise TypeError(
            "Invalid type of HuggingFace processor. "
            f"Expected type: {processor_cls}, but "
            f"found type: {type(processor)}"
        )

    return processor


# Memoized variant; every argument must be hashable (see `_merge_mm_kwargs`).
cached_get_processor = lru_cache(get_processor)


203
@lru_cache
def get_processor_kwargs_type(
    processor: ProcessorMixin,
) -> type[processing_utils.ProcessingKwargs]:
    """Return the kwargs TypedDict accepted by ``processor.__call__``.

    Resolution order:

    1. The first type argument of the annotation on the variadic keyword
       parameter of ``type(processor).__call__``
       (e.g. ``**kwargs: Unpack[XProcessorKwargs]``).
    2. The first class in the processor's module namespace whose name ends
       with ``"ProcessorKwargs"``.
    3. ``processing_utils.ProcessingKwargs``. This is also the result if
       introspection raises; the exception is logged.

    Results are memoized by ``lru_cache``, keyed on the ``processor`` argument.
    """
    try:
        call_params = inspect.signature(type(processor).__call__).parameters
        # Match the variadic keyword parameter by kind rather than by name, so
        # spellings like `**processor_kwargs` are found as well.
        var_kwargs = next(
            (
                param
                for param in call_params.values()
                if param.kind is inspect.Parameter.VAR_KEYWORD
            ),
            None,
        )
        annotation = (
            inspect.Parameter.empty if var_kwargs is None else var_kwargs.annotation
        )

        # `inspect.signature` returns the raw annotation without evaluating it.
        # `get_type_hints` would raise if an annotation refers to a name that
        # is not importable at runtime. Only a parameterized annotation such as
        # `Unpack[XProcessorKwargs]` carries the kwargs type. Anything else
        # (a bare class, a string annotation) has no type arguments; it now
        # falls through to the module scan instead of raising IndexError,
        # which used to skip that fallback.
        if annotation is not None and annotation is not inspect.Parameter.empty:
            annotation_args = get_args(annotation)
            if annotation_args:
                return annotation_args[0]

        # Otherwise, look for a `*ProcessorKwargs` class in the processor's
        # module namespace. Skip non-class objects that merely share the suffix.
        module = importlib.import_module(type(processor).__module__)
        for name, obj in vars(module).items():
            if name.endswith("ProcessorKwargs") and isinstance(obj, type):
                return obj
    except Exception:
        logger.exception("Failed to collect processor kwargs")

    return processing_utils.ProcessingKwargs


@lru_cache
def get_processor_kwargs_keys(
    kwargs_cls: type[processing_utils.ProcessingKwargs],
) -> set[str]:
    """Collect the call-time kwarg names described by ``kwargs_cls``.

    For each modality bucket declared on ``kwargs_cls`` (``text_kwargs``,
    ``images_kwargs``, ``videos_kwargs``, ``audio_kwargs``), the field names of
    that bucket's class are gathered. The bucket names themselves are always part
    of the result. If introspection raises, the exception is logged and the
    names gathered so far, plus the bucket names, are returned.
    """
    modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
    collected: set[str] = set()

    try:
        # Resolve only the top-level bucket annotations of `kwargs_cls`.
        bucket_hints = get_type_hints(kwargs_cls)
        for bucket in modality_kwargs:
            if bucket not in bucket_hints:
                continue
            # Read raw `__annotations__` along the MRO instead of calling
            # get_type_hints() on the bucket. Its fields may reference names
            # that cannot be resolved at runtime (e.g. PILImageResampling), and
            # only the key names are needed here.
            field_annotations: dict[str, Any] = {}
            for klass in reversed(bucket_hints[bucket].__mro__):
                field_annotations.update(getattr(klass, "__annotations__", {}))
            collected.update(field_annotations)
    except Exception:
        logger.exception("Failed to collect processor kwargs")

    return collected | modality_kwargs
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279


def cached_get_processor_without_dynamic_kwargs(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load (and cache) a processor, dropping kwargs its ``__call__`` accepts.

    Any key reported by ``get_processor_kwargs_keys`` for this processor is
    removed from ``kwargs`` before the final ``cached_get_processor`` call. Those
    kwargs therefore affect neither loading nor the cache key. Presumably they are
    meant to be supplied per call instead; confirm with callers.
    """
    # Step 1: load a processor with no extra kwargs. Only its type is needed, to
    # look up which call-time kwargs it understands.
    processor = cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
    )

    # Step 2: use the temporary processor to collect the dynamic (call-time) keys
    dynamic_keys = get_processor_kwargs_keys(
        get_processor_kwargs_type(processor)  # type: ignore[arg-type]
    )

    # Step 3: keep only the kwargs that are not dynamic keys
    filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}

    # Step 4: load the final processor. Positional `args` are forwarded here and
    # in step 1, so they are not silently discarded.
    return cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **filtered_kwargs,
    )


299
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load (and cache) the processor for ``model_config``.

    For GGUF models, the processor is loaded from ``model_config.tokenizer`` at
    ``tokenizer_revision``, which must not itself be GGUF. Otherwise it is loaded
    from ``model_config.model`` at ``revision``. ``kwargs`` are merged and
    filtered by ``_merge_mm_kwargs``, then call-time kwargs are dropped by
    ``cached_get_processor_without_dynamic_kwargs``.
    """
    if is_gguf(model_config.model):
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load processor."
        )
        model = model_config.tokenizer
        revision = model_config.tokenizer_revision
    else:
        model = model_config.model
        revision = model_config.revision

    return cached_get_processor_without_dynamic_kwargs(
        model,
        revision=revision,
        trust_remote_code=model_config.trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **_merge_mm_kwargs(model_config, processor_cls, **kwargs),
    )


324
325
326
def get_feature_extractor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an audio feature extractor for the given model name
    via HuggingFace.

    ``processor_name`` is passed through ``convert_model_repo_to_path``, then
    loaded with ``AutoFeatureExtractor.from_pretrained``. ``args``, ``revision``
    (``None`` is passed through unchanged), ``trust_remote_code`` and ``kwargs``
    are forwarded as-is.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code`` is
            False; the message suggests enabling it.
        ValueError: Loading raised ``ValueError`` with ``trust_remote_code`` set.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the feature extractor class not existing or
        # not currently being imported, suggest using the --trust-remote-code
        # flag. Unlike AutoTokenizer, AutoFeatureExtractor does not separate
        # such errors.
        if not trust_remote_code:
            err_msg = (
                "Failed to load the feature extractor. If the feature "
                "extractor is a custom extractor not yet available in the "
                "HuggingFace transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        raise
    return cast(FeatureExtractionMixin, feature_extractor)


# Memoized variant; every argument must be hashable.
cached_get_feature_extractor = lru_cache(get_feature_extractor)


def cached_feature_extractor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load (and cache) the feature extractor for ``model_config``.

    ``kwargs`` are merged and filtered against ``AutoFeatureExtractor`` by
    ``_merge_mm_kwargs``. GGUF models are handled the same way as in
    ``cached_processor_from_config``: the extractor is loaded from
    ``model_config.tokenizer`` at ``tokenizer_revision``, which must not itself
    be GGUF.
    """
    if is_gguf(model_config.model):
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load feature extractor."
        )
        model = model_config.tokenizer
        revision = model_config.tokenizer_revision
    else:
        model = model_config.model
        revision = model_config.revision

    return cached_get_feature_extractor(
        model,
        revision=revision,
        trust_remote_code=model_config.trust_remote_code,
        **_merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs),
    )


375
376
377
def get_image_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace.

    ``processor_name`` is passed through ``convert_model_repo_to_path``, then
    loaded with ``AutoImageProcessor.from_pretrained``. ``args``, ``revision``
    (``None`` is passed through unchanged), ``trust_remote_code`` and ``kwargs``
    are forwarded as-is.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code`` is
            False; the message suggests enabling it.
        ValueError: Loading raised ``ValueError`` with ``trust_remote_code`` set.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the image processor. If the image processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        raise

    return cast(BaseImageProcessor, processor)


# Memoized variant; every argument must be hashable.
cached_get_image_processor = lru_cache(get_image_processor)


def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load (and cache) the image processor for ``model_config``.

    For GGUF models, the image processor is loaded from ``model_config.tokenizer``
    at ``tokenizer_revision``, which must not itself be GGUF. Otherwise it is
    loaded from ``model_config.model`` at ``revision``. ``kwargs`` are merged and
    filtered against ``AutoImageProcessor`` by ``_merge_mm_kwargs``.
    """
    if is_gguf(model_config.model):
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load image processor."
        )
        model = model_config.tokenizer
        revision = model_config.tokenizer_revision
    else:
        model = model_config.model
        revision = model_config.revision

    return cached_get_image_processor(
        model,
        revision=revision,
        trust_remote_code=model_config.trust_remote_code,
        **_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
    )
434
435
436
437
438


def get_video_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls_overrides: type[_V] | None = None,
    **kwargs: Any,
):
    """Load a video processor for the given model name via HuggingFace.

    ``processor_name`` is passed through ``convert_model_repo_to_path``, then
    loaded with ``processor_cls_overrides.from_pretrained`` when an override is
    given, otherwise with ``AutoVideoProcessor.from_pretrained``. ``args``,
    ``revision`` (``None`` is passed through unchanged), ``trust_remote_code``
    and ``kwargs`` are forwarded as-is.

    Raises:
        RuntimeError: Loading raised ``ValueError`` and ``trust_remote_code`` is
            False; the message suggests enabling it.
        ValueError: Loading raised ``ValueError`` with ``trust_remote_code`` set.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        processor_cls = processor_cls_overrides or AutoVideoProcessor
        processor = processor_cls.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoVideoProcessor does not separate such errors
        if not trust_remote_code:
            err_msg = (
                "Failed to load the video processor. If the video processor is "
                "a custom processor not yet available in the HuggingFace "
                "transformers library, consider setting "
                "`trust_remote_code=True` in LLM or using the "
                "`--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        raise

    return cast(BaseVideoProcessor, processor)


# Memoized variant; every argument must be hashable.
cached_get_video_processor = lru_cache(get_video_processor)


def cached_video_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: type[_V] | None = None,
    **kwargs: Any,
):
    """Load (and cache) the video processor for ``model_config``.

    ``processor_cls``, when given, is used instead of ``AutoVideoProcessor`` to
    load. ``kwargs`` are merged and filtered against ``AutoVideoProcessor`` by
    ``_merge_mm_kwargs``. GGUF models are handled the same way as in
    ``cached_processor_from_config``: the video processor is loaded from
    ``model_config.tokenizer`` at ``tokenizer_revision``, which must not itself
    be GGUF.
    """
    if is_gguf(model_config.model):
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load video processor."
        )
        model = model_config.tokenizer
        revision = model_config.tokenizer_revision
    else:
        model = model_config.model
        revision = model_config.revision

    return cached_get_video_processor(
        model,
        revision=revision,
        trust_remote_code=model_config.trust_remote_code,
        processor_cls_overrides=processor_cls,  # type: ignore[arg-type]
        **_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs),
    )