processor.py 19.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import importlib
import inspect
6
from functools import lru_cache
7
from typing import TYPE_CHECKING, Any, cast, get_args, get_type_hints
8

9
10
11
12
13
from transformers import (
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoProcessor,
    AutoVideoProcessor,
14
    BatchFeature,
15
    processing_utils,
16
)
17
from transformers.audio_utils import AudioInput
18
19
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
20
from transformers.image_utils import ImageInput
21
from transformers.processing_utils import ProcessorMixin
22
from transformers.video_processing_utils import BaseVideoProcessor
23
from transformers.video_utils import VideoInput
24
25
from typing_extensions import TypeVar

26
from vllm.logger import init_logger
27
from vllm.transformers_utils import processors
28
from vllm.transformers_utils.gguf_utils import is_gguf
29
from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
30
from vllm.transformers_utils.utils import convert_model_repo_to_path
31
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
32

33
34
logger = init_logger(__name__)

35
if TYPE_CHECKING:
36
    from vllm.config import ModelConfig
37

38
39
40
41
42
43
44
45
46
47
48
49
50
51
52

def _transformers_v4_compatibility_import():
    """Re-expose `ChatTemplateLoadKwargs` on `processing_utils` when it is missing.

    In Transformers v4, `ChatTemplateLoadKwargs` was a subset of
    `ProcessorChatTemplateKwargs`. Transformers v5 merged the two and removed
    `ChatTemplateLoadKwargs`, but some remote code processors still import the
    old name. When the old name is absent and the new one exists, the old name
    is added back as an alias of the new one.

    This can be removed if `HCXVisionForCausalLM` is upstreamed to Transformers.
    """
    if getattr(processing_utils, "ChatTemplateLoadKwargs", None) is not None:
        # The old name is still available, so there is nothing to alias.
        return

    new_cls = getattr(processing_utils, "ProcessorChatTemplateKwargs", None)
    if new_cls is not None:
        processing_utils.ChatTemplateLoadKwargs = new_cls


53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def _transformers_v4_compatibility_init() -> Any:
    """Wrap `ProcessorMixin.__init__` so it tolerates `optional_attributes` kwargs.

    Some remote code processors define `optional_attributes` on their
    `ProcessorMixin` subclass and pass those attributes straight to
    `ProcessorMixin.__init__`, which Transformers v5 no longer allows. The
    wrapper installed here pops every such attribute out of `kwargs`, sets it
    on the instance, and then delegates to the original `__init__`.

    Nothing is patched when `ProcessorMixin` itself still has
    `optional_attributes` (Transformers v4), when the wrapper is already
    installed, or when `ProcessorMixin` is a mock (docs builds).

    This can be removed if `Molmo2ForConditionalGeneration` is upstreamed to
    Transformers.
    """
    # First test: Transformers v4. Second test: v5 that has already been wrapped.
    if hasattr(ProcessorMixin, "optional_attributes") or hasattr(
        ProcessorMixin.__init__, "_vllm_patched"
    ):
        return

    wrapped_init = ProcessorMixin.__init__

    def __init__(self, *args, **kwargs):
        declared = getattr(self, "optional_attributes", [])
        for attr_name in declared:
            if attr_name in kwargs:
                # Move the value onto the instance so the original __init__
                # never sees the unexpected keyword.
                setattr(self, attr_name, kwargs.pop(attr_name))

        wrapped_init(self, *args, **kwargs)

    # Only patch if ProcessorMixin is not mocked (for docs builds)
    if hasattr(ProcessorMixin, "_mock_name"):
        return

    # The marker lets a later call detect that the wrapper is already in place.
    __init__._vllm_patched = True  # type: ignore[attr-defined]
    ProcessorMixin.__init__ = __init__


84
# Apply the Transformers v4/v5 compatibility shims once, at import time.
_transformers_v4_compatibility_import()
_transformers_v4_compatibility_init()

# Type variables for the loader helpers below; each one defaults to its bound.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
_V = TypeVar("_V", bound=BaseVideoProcessor, default=BaseVideoProcessor)
89
90
91
92
93
94
95
96
97
98
99
100
101


class HashableDict(dict):
    """
    A `dict` subclass that defines `__hash__` so it can be used as an
    argument to `lru_cache`-wrapped functions.
    """

    # NOTE: a plain dict is not hashable. The hash is computed from the
    # current (key, value) pairs, so every value must itself be hashable.
    def __hash__(self) -> int:  # type: ignore[override]
        pairs = frozenset(self.items())
        return hash(pairs)


102
103
104
105
106
107
108
109
110
class HashableList(list):
    """
    A `list` subclass that defines `__hash__` so it can be used as an
    argument to `lru_cache`-wrapped functions.
    """

    # The hash is that of the equivalent tuple, so every element must be
    # hashable.
    def __hash__(self) -> int:  # type: ignore[override]
        as_tuple = tuple(self)
        return hash(as_tuple)


111
def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
    """Return the callable treated as the factory for `processor_cls`.

    - A tuple of classes, or `ProcessorMixin` itself, maps to
      `AutoProcessor.from_pretrained`.
    - A class with a `from_pretrained` attribute maps to that attribute.
    - Any other class is returned unchanged.
    """
    use_auto = isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin
    if use_auto:
        return AutoProcessor.from_pretrained

    # Fall back to the class itself when it has no `from_pretrained`.
    return getattr(processor_cls, "from_pretrained", processor_cls)

119

120
121
def _merge_mm_kwargs(
    model_config: "ModelConfig",
    processor_cls: type | tuple[type, ...],
    /,
    **kwargs,
):
    """Merge `kwargs` with the multimodal config and make the result cacheable.

    The keyword arguments are first merged through the multimodal config's
    `merge_mm_processor_kwargs`, then filtered by
    `get_allowed_kwarg_only_overrides` against the factory selected for
    `processor_cls`. Top-level `dict` and `list` values in the result are
    wrapped in `HashableDict` / `HashableList`.

    Returns:
        The filtered keyword-argument dict, with its values wrapped in place.
    """
    merged = model_config.get_multimodal_config().merge_mm_processor_kwargs(kwargs)

    accepted = get_allowed_kwarg_only_overrides(
        _get_processor_factory_fn(processor_cls),
        merged,
        requires_kw_only=False,
        allow_var_kwargs=True,
    )

    # NOTE: Plain dicts and lists are not hashable and would raise an
    # unhashable type error when passed to the `lru_cache`-wrapped loaders,
    # so they are wrapped in hashable subclasses. Only top-level values are
    # wrapped; containers nested inside them are left as they are.
    for name in list(accepted):
        item = accepted[name]
        if isinstance(item, dict):
            accepted[name] = HashableDict(item)
        elif isinstance(item, list):
            accepted[name] = HashableList(item)

    return accepted
146

147

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def get_processor_cls_name_from_config(
    processor_name: str,
    revision: str | None = "main",
) -> str | None:
    """Look up the `processor_class` entry in the model's config files.

    `processor_config.json`, `preprocessor_config.json` and
    `tokenizer_config.json` are read in that order via `get_hf_file_to_dict`.

    Returns:
        The first `processor_class` value found, or `None` when no file
        provides one.
    """
    candidate_files = (
        "processor_config.json",
        "preprocessor_config.json",
        "tokenizer_config.json",
    )
    for filename in candidate_files:
        loaded = get_hf_file_to_dict(filename, processor_name, revision=revision)
        if not loaded:
            # Nothing usable was loaded for this file; try the next one.
            continue
        if "processor_class" in loaded:
            return loaded["processor_class"]
    return None


164
165
def get_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load a processor for the given model name via HuggingFace.

    Args:
        processor_name: Model name or path; it is passed through
            `convert_model_repo_to_path` before loading.
        *args: Extra positional arguments forwarded to the loader.
        revision: Revision to load; `None` is treated as `"main"`.
        trust_remote_code: Forwarded to `from_pretrained`.
        processor_cls: Expected processor class, or a tuple of classes.
        **kwargs: Extra keyword arguments forwarded to the loader.

    Returns:
        The loaded processor, which is an instance of `processor_cls`.

    Raises:
        RuntimeError: Loading raised `ValueError` and `trust_remote_code` is
            false. The original error is chained as the cause.
        ValueError: Loading raised `ValueError` and `trust_remote_code` is true.
        TypeError: The loaded object is not an instance of `processor_cls`.
    """
    revision = "main" if revision is None else revision

    try:
        processor_name = convert_model_repo_to_path(processor_name)

        cls_name = get_processor_cls_name_from_config(
            processor_name, revision=revision
        )
        registered_processor_cls = cast(
            type[_P] | None,
            getattr(processors, cls_name, None) if cls_name else None,
        )

        # Decide which class provides `from_pretrained`. `None` means
        # `processor_cls` is instantiated directly instead.
        if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
            # No explicit class requested: prefer the registered processor
            # class when one is available.
            loader_cls = registered_processor_cls or AutoProcessor
        elif issubclass(processor_cls, ProcessorMixin):
            loader_cls = processor_cls
        else:
            loader_cls = None

        if loader_cls is None:
            # Processors that are standalone classes unrelated to HF
            processor = processor_cls(*args, **kwargs)
        else:
            processor = loader_cls.from_pretrained(
                processor_name,
                *args,
                revision=revision,
                trust_remote_code=trust_remote_code,
                **kwargs,
            )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
        if trust_remote_code:
            raise e

        err_msg = (
            "Failed to load the processor. If the processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI."
        )
        raise RuntimeError(err_msg) from e

    if isinstance(processor, processor_cls):
        return processor

    raise TypeError(
        "Invalid type of HuggingFace processor. "
        f"Expected type: {processor_cls}, but "
        f"found type: {type(processor)}"
    )


# `lru_cache` requires every argument to be hashable.
cached_get_processor = lru_cache(get_processor)


237
@lru_cache
def get_processor_kwargs_type(
    processor: ProcessorMixin,
) -> type[processing_utils.ProcessingKwargs]:
    """Find the kwargs type that describes the call-time kwargs of `processor`.

    Lookup order:
    1. The first type argument of the annotation on the `kwargs` parameter of
       `type(processor).__call__`, if it has one.
    2. The first attribute of the processor's defining module whose name ends
       with `ProcessorKwargs`.
    3. `processing_utils.ProcessingKwargs`.

    Any exception raised during lookup is logged, and the result falls back
    to option 3.
    """
    try:
        # Read the raw annotation from the signature. get_type_hints() would
        # resolve every annotation at runtime and raise if one refers to a
        # type or name that has not been imported or defined.
        signature = inspect.signature(type(processor).__call__)
        var_kwargs = signature.parameters.get("kwargs")
        annotation = var_kwargs.annotation if var_kwargs else None

        # If the processor has an explicit kwargs annotation, use it.
        if annotation not in (None, inspect.Parameter.empty):
            type_args = get_args(annotation)
            if type_args:
                return type_args[0]

        # Otherwise, look for a *ProcessorKwargs name in the processor's module.
        module = importlib.import_module(type(processor).__module__)
        for attr_name, attr in vars(module).items():
            if attr_name.endswith("ProcessorKwargs"):
                return attr

    except Exception:
        logger.exception("Failed to collect processor kwargs")

    return processing_utils.ProcessingKwargs


@lru_cache
def get_processor_kwargs_keys(
    kwargs_cls: type[processing_utils.ProcessingKwargs],
) -> set[str]:
    """Collect the keyword names declared by the modality groups of `kwargs_cls`.

    For each of `text_kwargs`, `images_kwargs`, `videos_kwargs` and
    `audio_kwargs` that appears in the type hints of `kwargs_cls`, the
    annotation names of the hinted class and all of its bases are gathered.

    Returns:
        The gathered names together with the four group names themselves. If
        collection raises, the exception is logged and whatever was gathered
        before the failure is still returned.
    """
    modality_kwargs = {
        "text_kwargs",
        "images_kwargs",
        "videos_kwargs",
        "audio_kwargs",
    }
    dynamic_kwargs: set[str] = set()

    try:
        # Resolve the per-modality kwargs classes declared on kwargs_cls.
        hints = get_type_hints(kwargs_cls)
        for group in modality_kwargs:
            if group not in hints:
                continue

            # Use __annotations__ instead of get_type_hints() to avoid
            # NameError from unresolved forward references (e.g.
            # PILImageResampling). We only need key names, not types.
            group_cls = hints[group]
            names: set[str] = set()
            for base in reversed(group_cls.__mro__):
                names.update(getattr(base, "__annotations__", {}))
            dynamic_kwargs |= names

    except Exception:
        logger.exception("Failed to collect processor kwargs")

    return dynamic_kwargs | modality_kwargs
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319


def cached_get_processor_without_dynamic_kwargs(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load a cached processor, ignoring kwargs that belong to the call step.

    A temporary processor is loaded with no extra kwargs, its kwargs type is
    inspected to find the "dynamic" keys, and those keys are dropped from
    `kwargs` before the final cached load.

    Args:
        processor_name: Model name or path passed to `cached_get_processor`.
        *args: Extra positional arguments, forwarded to both
            `cached_get_processor` calls. They take part in the `lru_cache`
            key, so they must be hashable.
        revision: Revision forwarded to `cached_get_processor`.
        trust_remote_code: Forwarded to `cached_get_processor`.
        processor_cls: Expected processor class, or a tuple of classes.
        **kwargs: Candidate init kwargs. Keys reported by
            `get_processor_kwargs_keys` are removed before the final load.

    Returns:
        The processor returned by `cached_get_processor` for the filtered
        kwargs.
    """
    # Step 1: load a temporary processor instance with default kwargs.
    # `*args` is forwarded here and in step 4; it used to be accepted by the
    # signature and then silently dropped.
    processor = cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
    )

    # Step 2: use the temporary processor to collect the dynamic keys.
    dynamic_keys = get_processor_kwargs_keys(
        get_processor_kwargs_type(processor)  # type: ignore[arg-type]
    )

    # Step 3: drop every kwarg whose name is a dynamic key.
    filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}

    # Step 4: load the final processor instance with the filtered kwargs.
    return cached_get_processor(
        processor_name,
        *args,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **filtered_kwargs,
    )


339
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Load the cached processor described by `model_config`.

    When `model_config.model` is a GGUF model, the processor is loaded from
    `model_config.tokenizer` at `model_config.tokenizer_revision`. Otherwise
    it is loaded from `model_config.model` at `model_config.revision`.
    `kwargs` are merged with the multimodal config by `_merge_mm_kwargs`
    before loading.
    """
    if not is_gguf(model_config.model):
        model, revision = model_config.model, model_config.revision
    else:
        # The tokenizer must point at the original (non-GGUF) model.
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load processor."
        )
        model, revision = model_config.tokenizer, model_config.tokenizer_revision

    trust_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, processor_cls, **kwargs)
    return cached_get_processor_without_dynamic_kwargs(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **mm_kwargs,
    )


364
365
366
def get_feature_extractor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an audio feature extractor for the given model name
    via HuggingFace.

    The name is passed through `convert_model_repo_to_path`, then loaded with
    `AutoFeatureExtractor.from_pretrained`; `*args` and `**kwargs` are
    forwarded to that call.

    Raises:
        RuntimeError: Loading raised `ValueError` and `trust_remote_code` is
            false. The original error is chained as the cause.
        ValueError: Loading raised `ValueError` and `trust_remote_code` is true.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        extractor = AutoFeatureExtractor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the extractor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoFeatureExtractor does not separate such
        # errors
        if trust_remote_code:
            raise e

        err_msg = (
            "Failed to load the feature extractor. If the feature "
            "extractor is a custom extractor not yet available in the "
            "HuggingFace transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI."
        )
        raise RuntimeError(err_msg) from e

    return cast(FeatureExtractionMixin, extractor)


# `lru_cache` requires every argument to be hashable.
cached_get_feature_extractor = lru_cache(get_feature_extractor)


def cached_feature_extractor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load the cached feature extractor for `model_config.model`.

    `kwargs` are merged with the multimodal config by `_merge_mm_kwargs`
    (against `AutoFeatureExtractor`) before being passed to
    `cached_get_feature_extractor`.
    """
    model = model_config.model
    revision = model_config.revision
    trust_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs)

    return cached_get_feature_extractor(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
        **mm_kwargs,
    )


415
416
417
def get_image_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    **kwargs: Any,
):
    """Load an image processor for the given model name via HuggingFace.

    The name is passed through `convert_model_repo_to_path`, then loaded with
    `AutoImageProcessor.from_pretrained`; `*args` and `**kwargs` are
    forwarded to that call.

    Raises:
        RuntimeError: Loading raised `ValueError` and `trust_remote_code` is
            false. The original error is chained as the cause.
        ValueError: Loading raised `ValueError` and `trust_remote_code` is true.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)
        image_processor = AutoImageProcessor.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors
        if trust_remote_code:
            raise e

        err_msg = (
            "Failed to load the image processor. If the image processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI."
        )
        raise RuntimeError(err_msg) from e

    return cast(BaseImageProcessor, image_processor)


# `lru_cache` requires every argument to be hashable.
cached_get_image_processor = lru_cache(get_image_processor)


def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Load the cached image processor described by `model_config`.

    When `model_config.model` is a GGUF model, the image processor is loaded
    from `model_config.tokenizer` at `model_config.tokenizer_revision`.
    Otherwise it is loaded from `model_config.model` at
    `model_config.revision`. `kwargs` are merged with the multimodal config
    by `_merge_mm_kwargs` (against `AutoImageProcessor`) before loading.
    """
    if not is_gguf(model_config.model):
        model, revision = model_config.model, model_config.revision
    else:
        # The tokenizer must point at the original (non-GGUF) model.
        assert not is_gguf(model_config.tokenizer), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load image processor."
        )
        model, revision = model_config.tokenizer, model_config.tokenizer_revision

    trust_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs)
    return cached_get_image_processor(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
        **mm_kwargs,
    )
474
475
476
477
478


def get_video_processor(
    processor_name: str,
    *args: Any,
    revision: str | None = None,
    trust_remote_code: bool = False,
    processor_cls_overrides: type[_V] | None = None,
    **kwargs: Any,
):
    """Load a video processor for the given model name via HuggingFace.

    The name is passed through `convert_model_repo_to_path`, then loaded with
    `from_pretrained` on `processor_cls_overrides` when one is given, and on
    `AutoVideoProcessor` otherwise. `*args` and `**kwargs` are forwarded to
    that call.

    Raises:
        RuntimeError: Loading raised `ValueError` and `trust_remote_code` is
            false. The original error is chained as the cause.
        ValueError: Loading raised `ValueError` and `trust_remote_code` is true.
    """
    try:
        processor_name = convert_model_repo_to_path(processor_name)

        loader_cls = processor_cls_overrides or AutoVideoProcessor
        video_processor = loader_cls.from_pretrained(
            processor_name,
            *args,
            revision=revision,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
    except ValueError as e:
        # If the error pertains to the processor class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        # Unlike AutoTokenizer, AutoVideoProcessor does not separate such errors
        if trust_remote_code:
            raise e

        err_msg = (
            "Failed to load the video processor. If the video processor is "
            "a custom processor not yet available in the HuggingFace "
            "transformers library, consider setting "
            "`trust_remote_code=True` in LLM or using the "
            "`--trust-remote-code` flag in the CLI."
        )
        raise RuntimeError(err_msg) from e

    return cast(BaseVideoProcessor, video_processor)


# `lru_cache` requires every argument to be hashable.
cached_get_video_processor = lru_cache(get_video_processor)


def cached_video_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: type[_V] | None = None,
    **kwargs: Any,
):
    """Load the cached video processor for `model_config.model`.

    `processor_cls`, when given, is passed on as `processor_cls_overrides`.
    `kwargs` are merged with the multimodal config by `_merge_mm_kwargs`
    (against `AutoVideoProcessor`) before being passed to
    `cached_get_video_processor`.
    """
    model = model_config.model
    revision = model_config.revision
    trust_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs)

    return cached_get_video_processor(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
        processor_cls_overrides=processor_cls,  # type: ignore[arg-type]
        **mm_kwargs,
    )
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568


def call_hf_processor_mm_only(
    processor: ProcessorMixin,
    images: ImageInput | None = None,
    videos: VideoInput | None = None,
    audio: AudioInput | None = None,
    **kwargs,
) -> BatchFeature:
    """Run only the multimodal sub-processors of `processor`.

    `kwargs` are split per modality with `processor._merge_kwargs`. Each
    modality is processed only when its input is not `None` and the
    processor has a truthy `feature_extractor`, `image_processor` or
    `video_processor` attribute respectively.

    Returns:
        A `BatchFeature` holding the merged audio, image and video outputs,
        with `tensor_type` taken from `kwargs["return_tensors"]` if present.
        On key collisions, later modalities override earlier ones.
    """
    merged = processor._merge_kwargs(get_processor_kwargs_type(processor), **kwargs)

    audio_inputs = {}
    if audio is not None:
        feature_extractor = getattr(processor, "feature_extractor", None)
        if feature_extractor:
            audio_inputs = feature_extractor(audio, **merged["audio_kwargs"])
            # The audio mask is renamed to `feature_attention_mask`.
            # NOTE(review): this raises KeyError if the feature extractor does
            # not return `attention_mask`. Confirm that is always present.
            audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")

    images_inputs = {}
    if images is not None:
        image_processor = getattr(processor, "image_processor", None)
        if image_processor:
            images_inputs = image_processor(images=images, **merged["images_kwargs"])

    videos_inputs = {}
    if videos is not None:
        video_processor = getattr(processor, "video_processor", None)
        if video_processor:
            videos_inputs = video_processor(videos=videos, **merged["videos_kwargs"])

    combined = {**audio_inputs, **images_inputs, **videos_inputs}
    return BatchFeature(data=combined, tensor_type=kwargs.get("return_tensors"))