Unverified Commit 82de9b9d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Automatically resolve HF processor init kwargs (#22005)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent ad57f23f
......@@ -331,10 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return hf_processor
def get_image_processor(self):
hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
return image_processor
def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor
def get_model_version(self):
return get_version_by_config(self.get_hf_config())
......
......@@ -533,7 +533,7 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
return self.ctx.get_hf_processor(Llama4Processor,
use_fast=True,
use_fast=kwargs.pop("use_fast", True),
**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
......
......@@ -137,34 +137,16 @@ class NemotronVLProcessor(InternVLProcessor):
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
"""Processing info for Nemotron VL models."""
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NemotronVLProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
image_processor = self.get_image_processor()
def get_hf_processor(self, **kwargs: object) -> NemotronVLProcessor:
return self.ctx.init_processor(
NemotronVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
image_processor=image_processor,
image_processor=self.get_image_processor(),
**kwargs,
)
def get_image_processor(
self,
**kwargs: object,
):
def get_image_processor(self, **kwargs: object):
return cached_image_processor_from_config(
self.ctx.model_config,
**kwargs,
......
......@@ -63,21 +63,7 @@ class NVLMProcessor(BaseInternVLProcessor):
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NVLMProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
return self.ctx.init_processor(
NVLMProcessor,
config=self.get_hf_config(),
......
......@@ -25,7 +25,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
......@@ -245,11 +245,12 @@ class VisualEmbedding(torch.nn.Embedding):
class OvisProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(
OvisProcessor,
image_pad_token=self.get_image_pad_token(),
image_segment_len=self.get_image_segment_len(),
**kwargs,
)
def get_image_segment_len(self) -> int:
......@@ -269,9 +270,6 @@ class OvisProcessingInfo(BaseProcessingInfo):
text_model_type = hf_text_config.model_type
return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
......
......@@ -318,17 +318,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if num_crops is not None:
kwargs["num_crops"] = num_crops
return self.ctx.get_hf_processor(**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
......
......@@ -696,19 +696,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Phi4MultimodalConfig:
return self.ctx.get_hf_config(Phi4MultimodalConfig)
def get_hf_processor(
self,
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> Phi4MMProcessor:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
return self.ctx.get_hf_processor(**kwargs)
def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor:
return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs)
def get_feature_extractor(self) -> Phi4MultimodalFeatureExtractor:
return self.get_hf_processor().audio_processor
def get_feature_extractor(
self, **kwargs: object) -> Phi4MultimodalFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_image_processor(
self,
......@@ -1007,7 +1000,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
if audio_data:
audio_features = processed_outputs['audio_input_features']
sr = self.info.get_feature_extractor().sampling_rate
sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
feature_sizes = [
self.info.get_audio_num_frames(len(audio), sr)
for audio in audio_data
......@@ -1043,7 +1036,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_token_id = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor()
audio_processor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int):
images = mm_items.get_items(
......
......@@ -459,17 +459,6 @@ def cat_with_pad(tensors, dim, padding_value=0):
class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
return self.ctx.get_hf_processor(**kwargs)
@property
def image_tokens(self) -> list[str]:
return [f"<|image_{i+1}|>" for i in range(100)]
......@@ -487,8 +476,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
image_processor = processor.image_processor
return image_processor.dynamic_hd
def get_feature_extractor(self) -> SequenceFeatureExtractor:
return self.get_hf_processor().audio_processor
def get_feature_extractor(self,
**kwargs: object) -> SequenceFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None, "image": None}
......@@ -769,7 +759,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
sr = self.info.get_feature_extractor().sampling_rate
sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
if (audio_data := mm_data.get("audios", [])):
mm_data['audios'] = [(data, sr) for data in audio_data]
......@@ -816,7 +806,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
) -> Sequence[PromptUpdate]:
image_tokens: list[str] = self.info.image_tokens # type: ignore
audio_tokens: list[str] = self.info.audio_tokens # type: ignore
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int):
......
......@@ -132,50 +132,15 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config
def get_hf_processor(
self,
*,
sampling_rate: Optional[int] = None,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5OmniProcessor:
if fps is not None:
kwargs["fps"] = fps
# Monkey patch for Transformers v4.53
processor_class = Qwen2_5OmniProcessor
if processor_class.image_processor_class != "AutoImageProcessor":
processor_class.image_processor_class = "AutoImageProcessor"
if processor_class.video_processor_class != "AutoVideoProcessor":
processor_class.video_processor_class = "AutoVideoProcessor"
processor = self.ctx.get_hf_processor(
processor_class,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
return self.ctx.get_hf_processor(
Qwen2_5OmniProcessor,
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
if not hasattr(processor, "audio_token"):
processor.audio_token = "<|AUDIO|>"
if not hasattr(processor, "image_token"):
processor.image_token = "<|IMAGE|>"
if not hasattr(processor, "video_token"):
processor.video_token = "<|VIDEO|>"
return processor
def get_feature_extractor(
self,
*,
sampling_rate: Optional[int] = None,
**kwargs: object,
):
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self, **kwargs: object):
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
......
......@@ -780,25 +780,10 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5_VLConfig)
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5_VLProcessor:
if fps is not None:
kwargs["fps"] = fps
def get_hf_processor(self, **kwargs: object) -> Qwen2_5_VLProcessor:
return self.ctx.get_hf_processor(
Qwen2_5_VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
......
......@@ -86,22 +86,12 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig)
def get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> Qwen2AudioProcessor:
def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
......
......@@ -69,8 +69,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
......@@ -752,73 +750,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2VLConfig)
def get_hf_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLProcessor:
def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
return self.ctx.get_hf_processor(
Qwen2VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
use_fast=kwargs.pop("use_fast", True),
**kwargs,
)
def _get_image_processor_kwargs(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
mm_config = self.ctx.model_config.get_multimodal_config()
if mm_config.mm_processor_kwargs:
kwargs.update(mm_config.mm_processor_kwargs)
if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
if size is None:
size = {"shortest_edge": min_pixels}
else:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
kwargs["max_pixels"] = max_pixels
if size is None:
size = {"longest_edge": max_pixels}
else:
size["longest_edge"] = max_pixels
if size is not None:
kwargs["size"] = size
return kwargs
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLImageProcessor:
kwargs["use_fast"] = kwargs.get("use_fast", True)
return cached_image_processor_from_config(
self.ctx.model_config,
**self._get_image_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
**kwargs),
)
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
......@@ -1023,20 +963,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser()
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
......
......@@ -7,9 +7,8 @@
# Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
......@@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v(
return pixel_values
class BaseSkyworkR1VProcessor(ABC):
class SkyworkR1VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
......@@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC):
self.use_thumbnail: bool = config.use_thumbnail
@property
@abstractmethod
def image_token_id(self) -> int:
raise NotImplementedError
return self.tokenizer.get_vocab()[IMG_CONTEXT]
@abstractmethod
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
raise NotImplementedError
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num(
self,
......@@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC):
}
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
@abstractmethod
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseSkyworkR1VProcessor:
raise NotImplementedError
def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
......@@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[BaseSkyworkR1VProcessor],
processor: Optional[SkyworkR1VProcessor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
......@@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo)
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class SkyworkR1VDummyInputsBuilder(
BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
......@@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
}
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
class SkyworkR1VMultiModalProcessor(
BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
def _call_hf_processor(
self,
......@@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> SkyworkR1VProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor,
info=SkyworkR1VProcessingInfo,
......
......@@ -19,15 +19,7 @@ from .idefics3 import Idefics3ProcessingInfo
class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(
self,
*,
max_image_size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> SmolVLMProcessor:
if max_image_size is not None:
kwargs["max_image_size"] = max_image_size
def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token(
......
......@@ -178,13 +178,11 @@ class TarsierProcessingInfo(BaseProcessingInfo):
return get_vision_encoder_info(self.get_hf_config())
def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
hf_processor = self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
# Patch for patch_size if needed (copied from vLLM LLaVA)
if hasattr(hf_processor,
'patch_size') and hf_processor.patch_size is None:
patch_size = self.get_vision_encoder_info().get_patch_size()
hf_processor.patch_size = patch_size
return hf_processor
vision_info = self.get_vision_encoder_info()
kwargs.setdefault("patch_size", vision_info.get_patch_size())
return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
......
......@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
......@@ -189,10 +188,6 @@ class MultiModalProcessingInfo(BaseProcessingInfo):
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor
def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size
......
......@@ -71,13 +71,7 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
class UltravoxProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
config = self.ctx.model_config.hf_config
hf_processor = self.ctx.get_hf_processor(**kwargs)
......@@ -89,13 +83,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
return hf_processor
def get_feature_extractor(
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
......@@ -156,7 +146,7 @@ class UltravoxMultiModalProcessor(
audios = mm_data.pop("audios", [])
assert isinstance(audios, list)
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
......
......@@ -623,23 +623,22 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self,
sampling_rate: Optional[int] = None
) -> WhisperProcessor:
# HACK: Transformers 4.53.0 has issue with whisper tokenizer to
def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
# HACK: Transformers 4.53.2 has issue with whisper tokenizer to
# initialize processor. We use a monkeypatch to fix it here.
# See: https://github.com/vllm-project/vllm/issues/20224
processor_class = WhisperProcessor
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
if processor_class.tokenizer_class != tokenizer_class:
processor_class.tokenizer_class = tokenizer_class
return self.ctx.get_hf_processor(processor_class)
return self.ctx.get_hf_processor(processor_class, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
def get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor()
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
......@@ -702,7 +701,7 @@ class WhisperMultiModalProcessor(
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
......
......@@ -4,9 +4,15 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from transformers import (AutoFeatureExtractor, AutoImageProcessor,
AutoProcessor)
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar
from vllm.utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
from vllm.config import ModelConfig
......@@ -33,23 +39,42 @@ class HashableList(list):
return hash(tuple(self))
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
mm_config = model_config.get_multimodal_config()
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]):
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
return AutoProcessor.from_pretrained
if hasattr(processor_cls, "from_pretrained"):
return processor_cls.from_pretrained
return processor_cls
merged_kwargs = {**base_kwargs, **kwargs}
def _merge_mm_kwargs(
model_config: "ModelConfig",
processor_cls: Union[type, tuple[type, ...]],
/,
**kwargs,
):
mm_config = model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
factory = _get_processor_factory_fn(processor_cls)
allowed_kwargs = get_allowed_kwarg_only_overrides(
factory,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
for key, value in allowed_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
allowed_kwargs[key] = HashableDict(value)
if isinstance(value, list):
merged_kwargs[key] = HashableList(value)
return merged_kwargs
allowed_kwargs[key] = HashableList(value)
return allowed_kwargs
def get_processor(
......@@ -61,21 +86,29 @@ def get_processor(
**kwargs: Any,
) -> _P:
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
isinstance(processor_cls, tuple) else processor_cls)
if revision is None:
revision = "main"
try:
processor = processor_factory.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
processor = AutoProcessor.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
elif issubclass(processor_cls, ProcessorMixin):
processor = processor_cls.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
else:
# Processors that are standalone classes unrelated to HF
processor = processor_cls(*args, **kwargs)
except ValueError as e:
# If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
......@@ -112,7 +145,7 @@ def cached_processor_from_config(
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, processor_cls, **kwargs),
)
......@@ -125,10 +158,6 @@ def get_feature_extractor(
):
"""Load an audio feature extractor for the given model name
via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
try:
feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name,
......@@ -164,7 +193,7 @@ def cached_feature_extractor_from_config(
model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs),
)
......@@ -176,11 +205,6 @@ def get_image_processor(
**kwargs: Any,
):
"""Load an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try:
processor = AutoImageProcessor.from_pretrained(
processor_name,
......@@ -217,5 +241,5 @@ def cached_image_processor_from_config(
model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
**_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
)
......@@ -2010,49 +2010,6 @@ def supports_kw(
return False
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts,
giving priority to inference_kwargs if there are any collisions.
In the case that no kwarg overrides are provided, returns an empty
dict so that it can still be kwarg expanded into the callable later on.
If allow_var_kwargs=True, allows for things that can be expanded into
kwargs as long as they aren't naming collision for var_kwargs or potential
positional arguments.
"""
# Filter inference time multimodal processor kwargs provided
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
return mm_processor_kwargs
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Mapping[str, object]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment