Unverified Commit 82de9b9d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Automatically resolve HF processor init kwargs (#22005)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent ad57f23f
...@@ -331,10 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -331,10 +331,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
return hf_processor return hf_processor
def get_image_processor(self): def get_image_processor(self, **kwargs: object):
hf_processor = self.get_hf_processor() return self.get_hf_processor(**kwargs).image_processor
image_processor = hf_processor.image_processor # type: ignore
return image_processor
def get_model_version(self): def get_model_version(self):
return get_version_by_config(self.get_hf_config()) return get_version_by_config(self.get_hf_config())
......
...@@ -533,7 +533,7 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -533,7 +533,7 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> Llama4Processor: def get_hf_processor(self, **kwargs: object) -> Llama4Processor:
return self.ctx.get_hf_processor(Llama4Processor, return self.ctx.get_hf_processor(Llama4Processor,
use_fast=True, use_fast=kwargs.pop("use_fast", True),
**kwargs) **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
......
...@@ -137,34 +137,16 @@ class NemotronVLProcessor(InternVLProcessor): ...@@ -137,34 +137,16 @@ class NemotronVLProcessor(InternVLProcessor):
class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
"""Processing info for Nemotron VL models.""" """Processing info for Nemotron VL models."""
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> NemotronVLProcessor:
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NemotronVLProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
image_processor = self.get_image_processor()
return self.ctx.init_processor( return self.ctx.init_processor(
NemotronVLProcessor, NemotronVLProcessor,
config=self.get_hf_config(), config=self.get_hf_config(),
tokenizer=self.get_tokenizer(), tokenizer=self.get_tokenizer(),
image_processor=image_processor, image_processor=self.get_image_processor(),
**kwargs, **kwargs,
) )
def get_image_processor( def get_image_processor(self, **kwargs: object):
self,
**kwargs: object,
):
return cached_image_processor_from_config( return cached_image_processor_from_config(
self.ctx.model_config, self.ctx.model_config,
**kwargs, **kwargs,
......
...@@ -63,21 +63,7 @@ class NVLMProcessor(BaseInternVLProcessor): ...@@ -63,21 +63,7 @@ class NVLMProcessor(BaseInternVLProcessor):
class NVLMProcessingInfo(BaseInternVLProcessingInfo): class NVLMProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> NVLMProcessor:
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> NVLMProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor( return self.ctx.init_processor(
NVLMProcessor, NVLMProcessor,
config=self.get_hf_config(), config=self.get_hf_config(),
......
...@@ -25,7 +25,7 @@ import torch ...@@ -25,7 +25,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.nn.functional import gumbel_softmax, pad, softmax from torch.nn.functional import gumbel_softmax, pad, softmax
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
...@@ -245,11 +245,12 @@ class VisualEmbedding(torch.nn.Embedding): ...@@ -245,11 +245,12 @@ class VisualEmbedding(torch.nn.Embedding):
class OvisProcessingInfo(BaseProcessingInfo): class OvisProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor( return self.ctx.get_hf_processor(
OvisProcessor, OvisProcessor,
image_pad_token=self.get_image_pad_token(), image_pad_token=self.get_image_pad_token(),
image_segment_len=self.get_image_segment_len(), image_segment_len=self.get_image_segment_len(),
**kwargs,
) )
def get_image_segment_len(self) -> int: def get_image_segment_len(self) -> int:
...@@ -269,9 +270,6 @@ class OvisProcessingInfo(BaseProcessingInfo): ...@@ -269,9 +270,6 @@ class OvisProcessingInfo(BaseProcessingInfo):
text_model_type = hf_text_config.model_type text_model_type = hf_text_config.model_type
return IMAGE_PAD_TOKEN_MAP.get(text_model_type) return IMAGE_PAD_TOKEN_MAP.get(text_model_type)
def get_image_processor(self) -> BaseImageProcessor:
return self.get_hf_processor().image_processor # type: ignore
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
......
...@@ -318,17 +318,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -318,17 +318,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VProcessingInfo(BaseProcessingInfo): class Phi3VProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if num_crops is not None:
kwargs["num_crops"] = num_crops
return self.ctx.get_hf_processor(**kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
......
...@@ -696,19 +696,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ...@@ -696,19 +696,12 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> Phi4MultimodalConfig: def get_hf_config(self) -> Phi4MultimodalConfig:
return self.ctx.get_hf_config(Phi4MultimodalConfig) return self.ctx.get_hf_config(Phi4MultimodalConfig)
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor:
self, return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs)
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> Phi4MMProcessor:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
return self.ctx.get_hf_processor(**kwargs)
def get_feature_extractor(self) -> Phi4MultimodalFeatureExtractor: def get_feature_extractor(
return self.get_hf_processor().audio_processor self, **kwargs: object) -> Phi4MultimodalFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_image_processor( def get_image_processor(
self, self,
...@@ -1007,7 +1000,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -1007,7 +1000,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
if audio_data: if audio_data:
audio_features = processed_outputs['audio_input_features'] audio_features = processed_outputs['audio_input_features']
sr = self.info.get_feature_extractor().sampling_rate sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
feature_sizes = [ feature_sizes = [
self.info.get_audio_num_frames(len(audio), sr) self.info.get_audio_num_frames(len(audio), sr)
for audio in audio_data for audio in audio_data
...@@ -1043,7 +1036,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -1043,7 +1036,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_token_id = tokenizer.vocab[tokenizer.audio_token] audio_token_id = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor() audio_processor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int): def get_image_replacement_phi4mm(item_idx: int):
images = mm_items.get_items( images = mm_items.get_items(
......
...@@ -459,17 +459,6 @@ def cat_with_pad(tensors, dim, padding_value=0): ...@@ -459,17 +459,6 @@ def cat_with_pad(tensors, dim, padding_value=0):
class Phi4MMProcessingInfo(BaseProcessingInfo): class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_hf_processor(
self,
*,
dynamic_hd: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
if dynamic_hd is not None:
kwargs["dynamic_hd"] = dynamic_hd
return self.ctx.get_hf_processor(**kwargs)
@property @property
def image_tokens(self) -> list[str]: def image_tokens(self) -> list[str]:
return [f"<|image_{i+1}|>" for i in range(100)] return [f"<|image_{i+1}|>" for i in range(100)]
...@@ -487,8 +476,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ...@@ -487,8 +476,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
image_processor = processor.image_processor image_processor = processor.image_processor
return image_processor.dynamic_hd return image_processor.dynamic_hd
def get_feature_extractor(self) -> SequenceFeatureExtractor: def get_feature_extractor(self,
return self.get_hf_processor().audio_processor **kwargs: object) -> SequenceFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None, "image": None} return {"audio": None, "image": None}
...@@ -769,7 +759,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -769,7 +759,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
sr = self.info.get_feature_extractor().sampling_rate sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate
if (audio_data := mm_data.get("audios", [])): if (audio_data := mm_data.get("audios", [])):
mm_data['audios'] = [(data, sr) for data in audio_data] mm_data['audios'] = [(data, sr) for data in audio_data]
...@@ -816,7 +806,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -816,7 +806,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
image_tokens: list[str] = self.info.image_tokens # type: ignore image_tokens: list[str] = self.info.image_tokens # type: ignore
audio_tokens: list[str] = self.info.audio_tokens # type: ignore audio_tokens: list[str] = self.info.audio_tokens # type: ignore
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor(
**hf_processor_mm_kwargs)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
def get_image_replacement_phi4mm(item_idx: int): def get_image_replacement_phi4mm(item_idx: int):
......
...@@ -132,50 +132,15 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo, ...@@ -132,50 +132,15 @@ class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> Qwen2_5OmniProcessor:
self, return self.ctx.get_hf_processor(
*, Qwen2_5OmniProcessor,
sampling_rate: Optional[int] = None, use_fast=kwargs.pop("use_fast", True),
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5OmniProcessor:
if fps is not None:
kwargs["fps"] = fps
# Monkey patch for Transformers v4.53
processor_class = Qwen2_5OmniProcessor
if processor_class.image_processor_class != "AutoImageProcessor":
processor_class.image_processor_class = "AutoImageProcessor"
if processor_class.video_processor_class != "AutoVideoProcessor":
processor_class.video_processor_class = "AutoVideoProcessor"
processor = self.ctx.get_hf_processor(
processor_class,
image_processor=self.get_image_processor(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs, **kwargs,
) )
if not hasattr(processor, "audio_token"):
processor.audio_token = "<|AUDIO|>" def get_feature_extractor(self, **kwargs: object):
if not hasattr(processor, "image_token"): hf_processor = self.get_hf_processor(**kwargs)
processor.image_token = "<|IMAGE|>"
if not hasattr(processor, "video_token"):
processor.video_token = "<|VIDEO|>"
return processor
def get_feature_extractor(
self,
*,
sampling_rate: Optional[int] = None,
**kwargs: object,
):
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
......
...@@ -780,25 +780,10 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -780,25 +780,10 @@ class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2_5_VLConfig) return self.ctx.get_hf_config(Qwen2_5_VLConfig)
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> Qwen2_5_VLProcessor:
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
fps: Optional[Union[float, list[float]]] = None,
**kwargs: object,
) -> Qwen2_5_VLProcessor:
if fps is not None:
kwargs["fps"] = fps
return self.ctx.get_hf_processor( return self.ctx.get_hf_processor(
Qwen2_5_VLProcessor, Qwen2_5_VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels, use_fast=kwargs.pop("use_fast", True),
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs, **kwargs,
) )
......
...@@ -86,22 +86,12 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo): ...@@ -86,22 +86,12 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig) return self.ctx.get_hf_config(Qwen2AudioConfig)
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> Qwen2AudioProcessor:
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> Qwen2AudioProcessor:
return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs) return self.ctx.get_hf_processor(Qwen2AudioProcessor, **kwargs)
def get_feature_extractor( def get_feature_extractor(self,
self, **kwargs: object) -> WhisperFeatureExtractor:
*, hf_processor = self.get_hf_processor(**kwargs)
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
feature_extractor = hf_processor.feature_extractor # type: ignore feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
......
...@@ -69,8 +69,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder ...@@ -69,8 +69,6 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import (
cached_image_processor_from_config)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
...@@ -752,73 +750,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -752,73 +750,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2VLConfig) return self.ctx.get_hf_config(Qwen2VLConfig)
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> Qwen2VLProcessor:
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLProcessor:
return self.ctx.get_hf_processor( return self.ctx.get_hf_processor(
Qwen2VLProcessor, Qwen2VLProcessor,
image_processor=self.get_image_processor(min_pixels=min_pixels, use_fast=kwargs.pop("use_fast", True),
max_pixels=max_pixels,
size=size,
use_fast=kwargs.get(
"use_fast", True)),
**kwargs, **kwargs,
) )
def _get_image_processor_kwargs( def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
self, return self.get_hf_processor(**kwargs).image_processor
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
):
mm_config = self.ctx.model_config.get_multimodal_config()
if mm_config.mm_processor_kwargs:
kwargs.update(mm_config.mm_processor_kwargs)
if min_pixels is not None:
kwargs["min_pixels"] = min_pixels
if size is None:
size = {"shortest_edge": min_pixels}
else:
size["shortest_edge"] = min_pixels
if max_pixels is not None:
kwargs["max_pixels"] = max_pixels
if size is None:
size = {"longest_edge": max_pixels}
else:
size["longest_edge"] = max_pixels
if size is not None:
kwargs["size"] = size
return kwargs
def get_image_processor(
self,
*,
min_pixels: Optional[int] = None,
max_pixels: Optional[int] = None,
size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> Qwen2VLImageProcessor:
kwargs["use_fast"] = kwargs.get("use_fast", True)
return cached_image_processor_from_config(
self.ctx.model_config,
**self._get_image_processor_kwargs(min_pixels=min_pixels,
max_pixels=max_pixels,
size=size,
**kwargs),
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None} return {"image": None, "video": None}
...@@ -1023,20 +963,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ...@@ -1023,20 +963,6 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
def _get_data_parser(self) -> MultiModalDataParser: def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser() return Qwen2VLMultiModalDataParser()
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_kwargs = self.info._get_image_processor_kwargs(**mm_kwargs)
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -7,9 +7,8 @@ ...@@ -7,9 +7,8 @@
# Copyright (c) 2025 Skywork # Copyright (c) 2025 Skywork
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, TypeVar, Union from typing import Literal, Optional, TypedDict, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v( ...@@ -232,7 +231,7 @@ def image_to_pixel_values_skyworkr1v(
return pixel_values return pixel_values
class BaseSkyworkR1VProcessor(ABC): class SkyworkR1VProcessor:
""" """
This model doesn't define its own HF processor, This model doesn't define its own HF processor,
so we implement our own one here. so we implement our own one here.
...@@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC): ...@@ -279,17 +278,18 @@ class BaseSkyworkR1VProcessor(ABC):
self.use_thumbnail: bool = config.use_thumbnail self.use_thumbnail: bool = config.use_thumbnail
@property @property
@abstractmethod
def image_token_id(self) -> int: def image_token_id(self) -> int:
raise NotImplementedError return self.tokenizer.get_vocab()[IMG_CONTEXT]
@abstractmethod
def get_image_repl( def get_image_repl(
self, self,
feature_size: int, feature_size: int,
num_patches: Optional[int], num_patches: Optional[int],
) -> PromptUpdateDetails[str]: ) -> PromptUpdateDetails[str]:
raise NotImplementedError repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num( def resolve_min_max_num(
self, self,
...@@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC): ...@@ -426,35 +426,15 @@ class BaseSkyworkR1VProcessor(ABC):
} }
class SkyworkR1VProcessor(BaseSkyworkR1VProcessor): class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): def get_hf_processor(self, **kwargs: object) -> SkyworkR1VProcessor:
return self.ctx.init_processor(
@abstractmethod SkyworkR1VProcessor,
def get_hf_processor( config=self.get_hf_config(),
self, tokenizer=self.get_tokenizer(),
*, **kwargs,
min_dynamic_patch: Optional[int] = None, )
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> BaseSkyworkR1VProcessor:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
...@@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): ...@@ -464,7 +444,7 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Optional[BaseSkyworkR1VProcessor], processor: Optional[SkyworkR1VProcessor],
) -> int: ) -> int:
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo): ...@@ -500,10 +480,8 @@ class BaseSkyworkR1VProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint return largest_feature_pinpoint
_I = TypeVar("_I", bound=BaseSkyworkR1VProcessingInfo) class SkyworkR1VDummyInputsBuilder(
BaseDummyInputsBuilder[SkyworkR1VProcessingInfo]):
class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -527,7 +505,8 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
} }
class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): class SkyworkR1VMultiModalProcessor(
BaseMultiModalProcessor[SkyworkR1VProcessingInfo]):
def _call_hf_processor( def _call_hf_processor(
self, self,
...@@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -617,31 +596,6 @@ class SkyworkR1VMultiModalProcessor(BaseMultiModalProcessor[_I]):
] ]
class SkyworkR1VProcessingInfo(BaseSkyworkR1VProcessingInfo):
def get_hf_processor(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
**kwargs: object,
) -> SkyworkR1VProcessor:
if min_dynamic_patch is not None:
kwargs["min_dynamic_patch"] = min_dynamic_patch
if max_dynamic_patch is not None:
kwargs["max_dynamic_patch"] = max_dynamic_patch
if dynamic_image_size is not None:
kwargs["dynamic_image_size"] = dynamic_image_size
return self.ctx.init_processor(
SkyworkR1VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
SkyworkR1VMultiModalProcessor, SkyworkR1VMultiModalProcessor,
info=SkyworkR1VProcessingInfo, info=SkyworkR1VProcessingInfo,
......
...@@ -19,15 +19,7 @@ from .idefics3 import Idefics3ProcessingInfo ...@@ -19,15 +19,7 @@ from .idefics3 import Idefics3ProcessingInfo
class SmolVLMProcessingInfo(Idefics3ProcessingInfo): class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
self,
*,
max_image_size: Optional[dict[str, int]] = None,
**kwargs: object,
) -> SmolVLMProcessor:
if max_image_size is not None:
kwargs["max_image_size"] = max_image_size
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs) return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token( def _get_image_token(
......
...@@ -178,13 +178,11 @@ class TarsierProcessingInfo(BaseProcessingInfo): ...@@ -178,13 +178,11 @@ class TarsierProcessingInfo(BaseProcessingInfo):
return get_vision_encoder_info(self.get_hf_config()) return get_vision_encoder_info(self.get_hf_config())
def get_hf_processor(self, **kwargs: object) -> TarsierProcessor: def get_hf_processor(self, **kwargs: object) -> TarsierProcessor:
hf_processor = self.ctx.get_hf_processor(TarsierProcessor, **kwargs) vision_info = self.get_vision_encoder_info()
# Patch for patch_size if needed (copied from vLLM LLaVA)
if hasattr(hf_processor, kwargs.setdefault("patch_size", vision_info.get_patch_size())
'patch_size') and hf_processor.patch_size is None:
patch_size = self.get_vision_encoder_info().get_patch_size() return self.ctx.get_hf_processor(TarsierProcessor, **kwargs)
hf_processor.patch_size = patch_size
return hf_processor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
......
...@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo) BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
...@@ -189,10 +188,6 @@ class MultiModalProcessingInfo(BaseProcessingInfo): ...@@ -189,10 +188,6 @@ class MultiModalProcessingInfo(BaseProcessingInfo):
image_tokens = mm_tokens["num_image_tokens"][0] image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens return image_tokens
def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor
def get_max_image_size(self): def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size return 10_000, 10_000 # hardcode for arbitrary very large size
......
...@@ -71,13 +71,7 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, ...@@ -71,13 +71,7 @@ UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs,
class UltravoxProcessingInfo(BaseProcessingInfo): class UltravoxProcessingInfo(BaseProcessingInfo):
def get_hf_processor( def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
self,
*,
# Ignored in initialization
sampling_rate: Optional[int] = None,
**kwargs: object,
) -> ProcessorMixin:
config = self.ctx.model_config.hf_config config = self.ctx.model_config.hf_config
hf_processor = self.ctx.get_hf_processor(**kwargs) hf_processor = self.ctx.get_hf_processor(**kwargs)
...@@ -89,13 +83,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo): ...@@ -89,13 +83,9 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
return hf_processor return hf_processor
def get_feature_extractor( def get_feature_extractor(self,
self, **kwargs: object) -> WhisperFeatureExtractor:
*, hf_processor = self.get_hf_processor(**kwargs)
# Ignored in initialization
sampling_rate: Optional[int] = None,
) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
audio_processor = hf_processor.audio_processor # type: ignore audio_processor = hf_processor.audio_processor # type: ignore
feature_extractor = audio_processor.feature_extractor # type: ignore feature_extractor = audio_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
...@@ -156,7 +146,7 @@ class UltravoxMultiModalProcessor( ...@@ -156,7 +146,7 @@ class UltravoxMultiModalProcessor(
audios = mm_data.pop("audios", []) audios = mm_data.pop("audios", [])
assert isinstance(audios, list) assert isinstance(audios, list)
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
sampling_rate=feature_extractor.sampling_rate, sampling_rate=feature_extractor.sampling_rate,
......
...@@ -623,23 +623,22 @@ class WhisperProcessingInfo(BaseProcessingInfo): ...@@ -623,23 +623,22 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig: def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig) return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self, def get_hf_processor(self, **kwargs: object) -> WhisperProcessor:
sampling_rate: Optional[int] = None # HACK: Transformers 4.53.2 has issue with whisper tokenizer to
) -> WhisperProcessor:
# HACK: Transformers 4.53.0 has issue with whisper tokenizer to
# initialize processor. We use a monkeypatch to fix it here. # initialize processor. We use a monkeypatch to fix it here.
# See: https://github.com/vllm-project/vllm/issues/20224 # See: https://github.com/vllm-project/vllm/issues/20224
processor_class = WhisperProcessor processor_class = WhisperProcessor
tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast") tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast")
if processor_class.tokenizer_class != tokenizer_class: if processor_class.tokenizer_class != tokenizer_class:
processor_class.tokenizer_class = tokenizer_class processor_class.tokenizer_class = tokenizer_class
return self.ctx.get_hf_processor(processor_class) return self.ctx.get_hf_processor(processor_class, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1} return {"audio": 1}
def get_feature_extractor(self) -> WhisperFeatureExtractor: def get_feature_extractor(self,
hf_processor = self.get_hf_processor() **kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor) assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor return feature_extractor
...@@ -702,7 +701,7 @@ class WhisperMultiModalProcessor( ...@@ -702,7 +701,7 @@ class WhisperMultiModalProcessor(
tok_kwargs: Mapping[str, object], tok_kwargs: Mapping[str, object],
) -> BatchFeature: ) -> BatchFeature:
if mm_data: if mm_data:
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios")) mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict( mm_kwargs = dict(
**mm_kwargs, **mm_kwargs,
......
...@@ -4,9 +4,15 @@ ...@@ -4,9 +4,15 @@
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Optional, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from transformers import (AutoFeatureExtractor, AutoImageProcessor,
AutoProcessor)
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -33,23 +39,42 @@ class HashableList(list): ...@@ -33,23 +39,42 @@ class HashableList(list):
return hash(tuple(self)) return hash(tuple(self))
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]):
mm_config = model_config.get_multimodal_config() if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
base_kwargs = mm_config.mm_processor_kwargs return AutoProcessor.from_pretrained
if base_kwargs is None: if hasattr(processor_cls, "from_pretrained"):
base_kwargs = {} return processor_cls.from_pretrained
return processor_cls
merged_kwargs = {**base_kwargs, **kwargs} def _merge_mm_kwargs(
model_config: "ModelConfig",
processor_cls: Union[type, tuple[type, ...]],
/,
**kwargs,
):
mm_config = model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
factory = _get_processor_factory_fn(processor_cls)
allowed_kwargs = get_allowed_kwarg_only_overrides(
factory,
merged_kwargs,
requires_kw_only=False,
allow_var_kwargs=True,
)
# NOTE: Pythonic dict is not hashable and will raise unhashable type # NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to # error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict. # wrap it to a hashable dict.
for key, value in merged_kwargs.items(): for key, value in allowed_kwargs.items():
if isinstance(value, dict): if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value) allowed_kwargs[key] = HashableDict(value)
if isinstance(value, list): if isinstance(value, list):
merged_kwargs[key] = HashableList(value) allowed_kwargs[key] = HashableList(value)
return merged_kwargs
return allowed_kwargs
def get_processor( def get_processor(
...@@ -61,21 +86,29 @@ def get_processor( ...@@ -61,21 +86,29 @@ def get_processor(
**kwargs: Any, **kwargs: Any,
) -> _P: ) -> _P:
"""Load a processor for the given model name via HuggingFace.""" """Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level if revision is None:
# it will call torch.cuda.device_count() revision = "main"
from transformers import AutoProcessor
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
isinstance(processor_cls, tuple) else processor_cls)
try: try:
processor = processor_factory.from_pretrained( if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
processor_name, processor = AutoProcessor.from_pretrained(
*args, processor_name,
revision=revision, *args,
trust_remote_code=trust_remote_code, revision=revision,
**kwargs, trust_remote_code=trust_remote_code,
) **kwargs,
)
elif issubclass(processor_cls, ProcessorMixin):
processor = processor_cls.from_pretrained(
processor_name,
*args,
revision=revision,
trust_remote_code=trust_remote_code,
**kwargs,
)
else:
# Processors that are standalone classes unrelated to HF
processor = processor_cls(*args, **kwargs)
except ValueError as e: except ValueError as e:
# If the error pertains to the processor class not existing or not # If the error pertains to the processor class not existing or not
# currently being imported, suggest using the --trust-remote-code flag. # currently being imported, suggest using the --trust-remote-code flag.
...@@ -112,7 +145,7 @@ def cached_processor_from_config( ...@@ -112,7 +145,7 @@ def cached_processor_from_config(
revision=model_config.revision, revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type] processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, processor_cls, **kwargs),
) )
...@@ -125,10 +158,6 @@ def get_feature_extractor( ...@@ -125,10 +158,6 @@ def get_feature_extractor(
): ):
"""Load an audio feature extractor for the given model name """Load an audio feature extractor for the given model name
via HuggingFace.""" via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
try: try:
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name, processor_name,
...@@ -164,7 +193,7 @@ def cached_feature_extractor_from_config( ...@@ -164,7 +193,7 @@ def cached_feature_extractor_from_config(
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs),
) )
...@@ -176,11 +205,6 @@ def get_image_processor( ...@@ -176,11 +205,6 @@ def get_image_processor(
**kwargs: Any, **kwargs: Any,
): ):
"""Load an image processor for the given model name via HuggingFace.""" """Load an image processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
try: try:
processor = AutoImageProcessor.from_pretrained( processor = AutoImageProcessor.from_pretrained(
processor_name, processor_name,
...@@ -217,5 +241,5 @@ def cached_image_processor_from_config( ...@@ -217,5 +241,5 @@ def cached_image_processor_from_config(
model_config.model, model_config.model,
revision=model_config.revision, revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
) )
...@@ -2010,49 +2010,6 @@ def supports_kw( ...@@ -2010,49 +2010,6 @@ def supports_kw(
return False return False
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
*,
requires_kw_only: bool = True,
allow_var_kwargs: bool = False,
) -> dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts,
giving priority to inference_kwargs if there are any collisions.
In the case that no kwarg overrides are provided, returns an empty
dict so that it can still be kwarg expanded into the callable later on.
If allow_var_kwargs=True, allows for things that can be expanded into
kwargs as long as they aren't naming collision for var_kwargs or potential
positional arguments.
"""
# Filter inference time multimodal processor kwargs provided
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=init_kwargs,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
return mm_processor_kwargs
def get_allowed_kwarg_only_overrides( def get_allowed_kwarg_only_overrides(
callable: Callable[..., object], callable: Callable[..., object],
overrides: Optional[Mapping[str, object]], overrides: Optional[Mapping[str, object]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment