Commit 0ab06100 (unverified), authored by Isotr0py, committed by GitHub
Browse files

[Multimodal] Expose `mm_processor_kwargs` for `DummyInputsBuilder` (#34330)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent ffb3d553
......@@ -618,6 +618,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.visual
......
......@@ -41,6 +41,7 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
......
......@@ -155,6 +155,7 @@ class SiglipDummyInputsBuilder(BaseDummyInputsBuilder[SiglipProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
......
......@@ -533,6 +533,7 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
......
......@@ -565,6 +565,7 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
......
......@@ -154,6 +154,7 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
# Dummy data is generated based on the 'input' section
# defined in the HF configuration file
......
......@@ -98,6 +98,7 @@ class MultiModalDummyInputsBuilder(BaseDummyInputsBuilder[MultiModalProcessingIn
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, "BaseDummyOptions"] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
......
......@@ -161,8 +161,11 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo])
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(
**(mm_processor_kwargs or {})
)
sampling_rate = feature_extractor.sampling_rate
audio_len = (
......
......@@ -220,6 +220,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
......@@ -238,6 +239,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
......
......@@ -685,8 +685,11 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
feature_extractor = self.info.get_feature_extractor(
**(mm_processor_kwargs or {})
)
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
......
......@@ -63,6 +63,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> MultiModalDataDict:
"""
Build the multimodal input which, after processing, results in
......@@ -83,6 +84,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
) -> ProcessorInputs:
"""
Build the input which, after processing, results in
......@@ -92,9 +94,16 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional)
mm_processor_kwargs: Additional keyword arguments
for hf_processor (optional)
"""
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_data = self.get_dummy_mm_data(
seq_len,
mm_counts,
mm_options,
mm_processor_kwargs=mm_processor_kwargs,
)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False)
tokenization_kwargs = {"truncation": False}
......@@ -102,6 +111,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
return ProcessorInputs(
prompt=dummy_text,
mm_items=dummy_mm_items,
hf_processor_mm_kwargs=mm_processor_kwargs or {},
tokenization_kwargs=tokenization_kwargs,
)
......
......@@ -257,10 +257,12 @@ class MultiModalRegistry:
if processor is None:
processor = self.create_processor(model_config, cache=cache)
mm_config = model_config.get_multimodal_config()
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
mm_counts=mm_counts,
mm_options=self._extract_mm_options(model_config),
mm_processor_kwargs=mm_config.mm_processor_kwargs,
)
mm_inputs = processor.apply(
prompt=processor_inputs.prompt,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment