Commit 987506bc (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Simplify dummy data generation (#35025)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c645e9a2
...@@ -357,15 +357,13 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( ...@@ -357,15 +357,13 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
mm_processor_kwargs = mm_processor_kwargs or {} feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(**mm_processor_kwargs)
target_audio_length = ( target_audio_length = (
min( min(
...@@ -375,16 +373,14 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( ...@@ -375,16 +373,14 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
* feature_extractor.sampling_rate * feature_extractor.sampling_rate
) )
target_width, target_height = self.info.get_image_size_with_most_features( target_width, target_height = self.info.get_image_size_with_most_features()
max_pixels=mm_processor_kwargs.get("max_pixels", None),
)
target_num_frames = self.info.get_num_frames_with_most_features( target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts seq_len, mm_counts
) )
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
video_overrides = mm_options.get("video") if mm_options else None video_overrides = mm_options.get("video")
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
mm_data = { mm_data = {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
......
...@@ -195,22 +195,21 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn ...@@ -195,22 +195,21 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor( feature_extractor = self.info.get_feature_extractor()
**(mm_processor_kwargs or {})
)
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
return { return {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
) )
} }
......
...@@ -925,9 +925,14 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -925,9 +925,14 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
if max_pixels is None: if max_pixels is None:
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
max_pixels = image_processor.size["longest_edge"]
mm_kwargs = self.ctx.get_merged_mm_kwargs({})
size = mm_kwargs.get("size", image_processor.size)
max_pixels = size["longest_edge"]
unit = patch_size * merge_size unit = patch_size * merge_size
max_seq_len = max_pixels // (unit * unit) max_seq_len = max_pixels // (unit * unit)
...@@ -1027,22 +1032,18 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ...@@ -1027,22 +1032,18 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
mm_processor_kwargs = mm_processor_kwargs or {} target_width, target_height = self.info.get_image_size_with_most_features()
target_width, target_height = self.info.get_image_size_with_most_features(
max_pixels=mm_processor_kwargs.get("max_pixels", None)
)
target_num_frames = self.info.get_num_frames_with_most_features( target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts seq_len, mm_counts
) )
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
video_overrides = mm_options.get("video") if mm_options else None video_overrides = mm_options.get("video")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -146,14 +146,11 @@ class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]) ...@@ -146,14 +146,11 @@ class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo])
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
feature_extractor = self.info.get_feature_extractor( feature_extractor = self.info.get_feature_extractor()
**(mm_processor_kwargs or {})
)
target_audio_length = ( target_audio_length = (
min( min(
...@@ -163,7 +160,7 @@ class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo]) ...@@ -163,7 +160,7 @@ class Qwen3ASRDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3ASRProcessingInfo])
* feature_extractor.sampling_rate * feature_extractor.sampling_rate
) )
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
return { return {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
......
...@@ -703,11 +703,18 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -703,11 +703,18 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> int: ) -> int:
video_processor = self.get_video_processor() video_processor = self.get_video_processor()
video_max_pixels = video_processor.size["longest_edge"]
mm_kwargs = self.ctx.get_merged_mm_kwargs({})
video_size = mm_kwargs.get("size", video_processor.size)
temporal_patch_size = mm_kwargs.get(
"temporal_patch_size", video_processor.temporal_patch_size
)
# video_max_pixels contains the temporal compression factor, # video_max_pixels contains the temporal compression factor,
# so we divide by 2 to get the maximum number of image pixels. # so we divide by 2 to get the maximum number of image pixels.
video_max_pixels = video_size["longest_edge"]
target_width, target_height = self.get_image_size_with_most_features( target_width, target_height = self.get_image_size_with_most_features(
max_pixels=video_max_pixels // video_processor.temporal_patch_size max_pixels=video_max_pixels // temporal_patch_size
) )
num_video_soft_tokens = self.get_num_video_tokens( num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
...@@ -789,19 +796,15 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): ...@@ -789,19 +796,15 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
video_overrides = mm_options.get("video") if mm_options else None video_overrides = mm_options.get("video")
mm_processor_kwargs = mm_processor_kwargs or {}
target_image_width, target_image_height = ( target_image_width, target_image_height = (
self.info.get_image_size_with_most_features( self.info.get_image_size_with_most_features()
max_pixels=mm_processor_kwargs.get("max_pixels", None),
)
) )
# treat videos as special images # treat videos as special images
...@@ -826,13 +829,20 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): ...@@ -826,13 +829,20 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
target_num_frames = min(target_num_frames, num_frames_override) target_num_frames = min(target_num_frames, num_frames_override)
target_num_frames = max(target_num_frames, 2) target_num_frames = max(target_num_frames, 2)
video_processor = self.info.get_video_processor(**(mm_processor_kwargs or {})) video_processor = self.info.get_video_processor()
video_max_pixels = video_processor.size["longest_edge"]
mm_kwargs = self.info.ctx.get_merged_mm_kwargs({})
video_size = mm_kwargs.get("size", video_processor.size)
temporal_patch_size = mm_kwargs.get(
"temporal_patch_size", video_processor.temporal_patch_size
)
# video_max_pixels contains the temporal compression factor, # video_max_pixels contains the temporal compression factor,
# so we divide by 2 to get the maximum number of image pixels. # so we divide by 2 to get the maximum number of image pixels.
video_max_pixels = video_size["longest_edge"]
target_video_width, target_video_height = ( target_video_width, target_video_height = (
self.info.get_image_size_with_most_features( self.info.get_image_size_with_most_features(
max_pixels=video_max_pixels // video_processor.temporal_patch_size max_pixels=video_max_pixels // temporal_patch_size
) )
) )
target_video_size, _ = self.info._get_vision_info( target_video_size, _ = self.info._get_vision_info(
......
...@@ -617,8 +617,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): ...@@ -617,8 +617,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.visual vision_config = hf_config.visual
...@@ -626,7 +625,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): ...@@ -626,7 +625,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
target_width = target_height = vision_config["image_size"] target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -40,14 +40,13 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): ...@@ -40,14 +40,13 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features() target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -158,14 +158,13 @@ class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]): ...@@ -158,14 +158,13 @@ class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features() target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -529,13 +529,12 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn ...@@ -529,13 +529,12 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features() target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -564,13 +564,12 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): ...@@ -564,13 +564,12 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features() target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -154,8 +154,7 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): ...@@ -154,8 +154,7 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
# Dummy data is generated based on the 'input' section # Dummy data is generated based on the 'input' section
# defined in the HF configuration file # defined in the HF configuration file
......
...@@ -101,14 +101,13 @@ class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingIn ...@@ -101,14 +101,13 @@ class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingIn
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, "BaseDummyOptions"] | None = None, mm_options: Mapping[str, "BaseDummyOptions"],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size() target_width, target_height = self.info.get_max_image_size()
image_overrides = mm_options.get("image") if mm_options else None image_overrides = mm_options.get("image")
return { return {
"image": self._get_dummy_images( "image": self._get_dummy_images(
......
...@@ -164,12 +164,9 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]) ...@@ -164,12 +164,9 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo])
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor( feature_extractor = self.info.get_feature_extractor()
**(mm_processor_kwargs or {})
)
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = ( audio_len = (
...@@ -177,11 +174,13 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]) ...@@ -177,11 +174,13 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo])
) )
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
return { return {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
) )
} }
......
...@@ -218,18 +218,19 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -218,18 +218,19 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
target_length = self.info.get_max_audio_array_len() target_length = self.info.get_max_audio_array_len()
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
return { return {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
length=target_length, num_audios=num_audios, overrides=audio_overrides length=target_length,
num_audios=num_audios,
overrides=audio_overrides,
) )
} }
...@@ -237,8 +238,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -237,8 +238,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
......
...@@ -695,22 +695,21 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): ...@@ -695,22 +695,21 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor( feature_extractor = self.info.get_feature_extractor()
**(mm_processor_kwargs or {})
)
sampling_rate = feature_extractor.sampling_rate sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None audio_overrides = mm_options.get("audio")
return { return {
"audio": self._get_dummy_audios( "audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides length=audio_len,
num_audios=num_audios,
overrides=audio_overrides,
) )
} }
......
...@@ -266,11 +266,14 @@ class InputProcessingContext: ...@@ -266,11 +266,14 @@ class InputProcessingContext:
if isinstance(tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
tokenizer = tokenizer.transformers_tokenizer tokenizer = tokenizer.transformers_tokenizer
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
merged_kwargs.pop("tokenizer", None)
return cached_processor_from_config( return cached_processor_from_config(
self.model_config, self.model_config,
processor_cls=typ, processor_cls=typ,
tokenizer=tokenizer, tokenizer=tokenizer,
**kwargs, **merged_kwargs,
) )
def init_processor( def init_processor(
...@@ -283,12 +286,7 @@ class InputProcessingContext: ...@@ -283,12 +286,7 @@ class InputProcessingContext:
Initialize a HuggingFace-like processor class, merging the Initialize a HuggingFace-like processor class, merging the
keyword arguments with those in the model's configuration. keyword arguments with those in the model's configuration.
""" """
mm_config = self.model_config.get_multimodal_config() merged_kwargs = self.get_merged_mm_kwargs(kwargs)
base_kwargs = mm_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return typ(**merged_kwargs) return typ(**merged_kwargs)
......
...@@ -62,8 +62,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -62,8 +62,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
""" """
Build the multimodal input which, after processing, results in Build the multimodal input which, after processing, results in
...@@ -83,8 +82,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -83,8 +82,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None, mm_options: Mapping[str, BaseDummyOptions],
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Build the input which, after processing, results in Build the input which, after processing, results in
...@@ -94,16 +92,9 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -94,16 +92,9 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
seq_len: Sequence length seq_len: Sequence length
mm_counts: Count of items per modality mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional) mm_options: Configurable options per modality (optional)
mm_processor_kwargs: Additional keyword arguments
for hf_processor (optional)
""" """
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data( dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
seq_len,
mm_counts,
mm_options,
mm_processor_kwargs=mm_processor_kwargs,
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False) dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)
tokenization_kwargs = {"truncation": False} tokenization_kwargs = {"truncation": False}
...@@ -111,7 +102,6 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -111,7 +102,6 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
return ProcessorInputs( return ProcessorInputs(
prompt=dummy_text, prompt=dummy_text,
mm_items=dummy_mm_items, mm_items=dummy_mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
) )
......
...@@ -5,7 +5,6 @@ from dataclasses import dataclass ...@@ -5,7 +5,6 @@ from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
...@@ -99,27 +98,6 @@ class MultiModalRegistry: ...@@ -99,27 +98,6 @@ class MultiModalRegistry:
A registry that dispatches data processing according to the model. A registry that dispatches data processing according to the model.
""" """
def _extract_mm_options(
self,
model_config: "ModelConfig",
) -> Mapping[str, BaseDummyOptions] | None:
"""
Extract multimodal dummy options from model config.
Returns None if no configurable options are found, otherwise returns
a mapping of modality names to their dummy options.
"""
if not model_config.multimodal_config:
return None
mm_options = {
m: opt
for m in model_config.multimodal_config.limit_per_prompt
if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
}
return mm_options if len(mm_options) > 0 else None
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
""" """
Checks if the model supports multimodal inputs. Checks if the model supports multimodal inputs.
...@@ -261,8 +239,7 @@ class MultiModalRegistry: ...@@ -261,8 +239,7 @@ class MultiModalRegistry:
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs( processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len, seq_len=seq_len,
mm_counts=mm_counts, mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config), mm_options=mm_config.limit_per_prompt,
mm_processor_kwargs=mm_config.mm_processor_kwargs,
) )
mm_inputs = processor.apply( mm_inputs = processor.apply(
prompt=processor_inputs.prompt, prompt=processor_inputs.prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment