Unverified commit 51550179, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Define MM data parser in processing info instead of processor itself (#33260)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 07ea184f
...@@ -143,6 +143,7 @@ def test_qwen3_omni_get_updates_use_audio_in_video( ...@@ -143,6 +143,7 @@ def test_qwen3_omni_get_updates_use_audio_in_video(
# Create processing info # Create processing info
info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx) info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
info._get_expected_hidden_size = lambda: 100
info.get_hf_config = Mock(return_value=mock_qwen3_omni_config) info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
info.get_hf_processor = Mock(return_value=mock_processor) info.get_hf_processor = Mock(return_value=mock_processor)
info.get_tokenizer = Mock(return_value=mock_tokenizer) info.get_tokenizer = Mock(return_value=mock_tokenizer)
......
...@@ -192,6 +192,22 @@ class AudioFlamingo3MultiModalProjector(nn.Module): ...@@ -192,6 +192,22 @@ class AudioFlamingo3MultiModalProjector(nn.Module):
return hidden_states return hidden_states
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_audioflamingo3_field_config,
)
return super()._parse_audio_data(data)
class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config(AudioFlamingo3Config) return self.ctx.get_hf_config(AudioFlamingo3Config)
...@@ -204,6 +220,14 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): ...@@ -204,6 +220,14 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
feature_extractor = hf_processor.feature_extractor feature_extractor = hf_processor.feature_extractor
return feature_extractor return feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None} return {"audio": None}
...@@ -259,30 +283,9 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -259,30 +283,9 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
) )
class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_audioflamingo3_field_config,
)
return super()._parse_audio_data(data)
class AudioFlamingo3MultiModalProcessor( class AudioFlamingo3MultiModalProcessor(
BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo] BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
): ):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate
)
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -227,10 +227,8 @@ class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingIn ...@@ -227,10 +227,8 @@ class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingIn
# HF processor pops the `num_patches` kwarg, which is needed by vLLM # HF processor pops the `num_patches` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None: if (images := mm_data.get("images")) is not None:
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
image_sizes = [ image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images)) parsed_images.get_image_size(i) for i in range(len(parsed_images))
......
...@@ -262,10 +262,8 @@ class Cohere2VisionMultiModalProcessor( ...@@ -262,10 +262,8 @@ class Cohere2VisionMultiModalProcessor(
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
# Fallback calculation if HF processor didn't provide num_patches # Fallback calculation if HF processor didn't provide num_patches
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
num_patches = [ num_patches = [
......
...@@ -793,6 +793,12 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -793,6 +793,12 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
return MultiModalDataParser(
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "video": None} return {"image": None, "video": None}
...@@ -947,11 +953,6 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -947,11 +953,6 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): class Ernie4_5VLMultiModalProcessor(BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(
video_needs_metadata=True,
)
def _pixel_values_norm( def _pixel_values_norm(
self, self,
pixel_values: torch.Tensor, pixel_values: torch.Tensor,
......
...@@ -552,6 +552,29 @@ class FunAudioChatDiscreteEncoder(nn.Module): ...@@ -552,6 +552,29 @@ class FunAudioChatDiscreteEncoder(nn.Module):
class FunAudioChatProcessingInfo(BaseProcessingInfo): class FunAudioChatProcessingInfo(BaseProcessingInfo):
token_fps: int = 25 token_fps: int = 25
@cached_property
def feature_extractor(self) -> WhisperFeatureExtractor:
return WhisperFeatureExtractor.from_pretrained(self.model_id)
@cached_property
def speech_tokenizer(self) -> PreTrainedTokenizerFast:
return PreTrainedTokenizerFast.from_pretrained(
self.model_id, subfolder="speech_tokenizer"
)
def get_feature_extractor(self) -> WhisperFeatureExtractor:
return self.feature_extractor
def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
return self.speech_tokenizer
def get_data_parser(self):
return MultiModalDataParser(
target_sr=int(self.feature_extractor.sampling_rate),
target_channels=self.get_target_channels(),
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None} return {"audio": None}
...@@ -570,22 +593,6 @@ class FunAudioChatProcessingInfo(BaseProcessingInfo): ...@@ -570,22 +593,6 @@ class FunAudioChatProcessingInfo(BaseProcessingInfo):
max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500)) max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
return {"audio": max_audio_tokens} return {"audio": max_audio_tokens}
@cached_property
def feature_extractor(self) -> WhisperFeatureExtractor:
return WhisperFeatureExtractor.from_pretrained(self.model_id)
@cached_property
def speech_tokenizer(self) -> PreTrainedTokenizerFast:
return PreTrainedTokenizerFast.from_pretrained(
self.model_id, subfolder="speech_tokenizer"
)
def get_feature_extractor(self) -> WhisperFeatureExtractor:
return self.feature_extractor
def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
return self.speech_tokenizer
def get_audio_group_size(self) -> int: def get_audio_group_size(self) -> int:
cfg = self.get_hf_config() cfg = self.get_hf_config()
audio_cfg = getattr(cfg, "audio_config", None) audio_cfg = getattr(cfg, "audio_config", None)
...@@ -635,13 +642,6 @@ class FunAudioChatDummyInputsBuilder( ...@@ -635,13 +642,6 @@ class FunAudioChatDummyInputsBuilder(
class FunAudioChatMultiModalProcessor( class FunAudioChatMultiModalProcessor(
BaseMultiModalProcessor[FunAudioChatProcessingInfo] BaseMultiModalProcessor[FunAudioChatProcessingInfo]
): ):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=int(feature_extractor.sampling_rate),
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -290,10 +290,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -290,10 +290,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM # HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None: if (images := mm_data.get("images")) is not None:
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
image_sizes = [ image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images)) parsed_images.get_image_size(i) for i in range(len(parsed_images))
......
...@@ -107,6 +107,17 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ...@@ -107,6 +107,17 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object): def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs) return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
def get_feature_extractor(self, **kwargs: object) -> Gemma3nAudioFeatureExtractor:
return self.get_hf_processor(**kwargs).feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "audio": None} return {"image": None, "audio": None}
...@@ -200,10 +211,6 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): ...@@ -200,10 +211,6 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]): class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().feature_extractor
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -822,6 +822,12 @@ class Glm4vProcessingInfo(BaseProcessingInfo): ...@@ -822,6 +822,12 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor: def get_video_processor(self, **kwargs: object) -> Glm4vVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor return self.get_hf_processor(**kwargs).video_processor
def get_data_parser(self):
return MultiModalDataParser(
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def _get_vision_info( def _get_vision_info(
self, self,
*, *,
...@@ -1222,9 +1228,6 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ...@@ -1222,9 +1228,6 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]): class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(video_needs_metadata=True)
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -620,64 +620,6 @@ class GlmAsrMultiModalProjector(nn.Module): ...@@ -620,64 +620,6 @@ class GlmAsrMultiModalProjector(nn.Module):
return hidden_states return hidden_states
class GlmAsrProcessingInfo(BaseProcessingInfo):
"""
Processing information provider for GLM-ASR model.
Provides access to model configuration, processor, and feature extractor
needed for audio preprocessing and multimodal integration.
"""
def get_hf_config(self) -> GlmAsrConfig:
return self.ctx.get_hf_config(GlmAsrConfig)
def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
return self.get_hf_processor(**kwargs).feature_extractor
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
"""
Builder for dummy inputs used in profiling and testing.
Generates dummy text prompts and audio data that match the expected
format for GLM-ASR model inputs. Used for memory profiling and
performance benchmarking.
"""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
return hf_processor.audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
max_audio_len = getattr(
self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S
)
audio_len = int(max_audio_len * sampling_rate)
return {
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
def _glmasr_field_config( def _glmasr_field_config(
hf_inputs: Mapping[str, torch.Tensor], hf_inputs: Mapping[str, torch.Tensor],
) -> dict[str, MultiModalFieldConfig]: ) -> dict[str, MultiModalFieldConfig]:
...@@ -737,16 +679,78 @@ class GlmAsrMultiModalDataParser(MultiModalDataParser): ...@@ -737,16 +679,78 @@ class GlmAsrMultiModalDataParser(MultiModalDataParser):
return super()._parse_audio_data(data) return super()._parse_audio_data(data)
class GlmAsrProcessingInfo(BaseProcessingInfo):
"""
Processing information provider for GLM-ASR model.
Provides access to model configuration, processor, and feature extractor
needed for audio preprocessing and multimodal integration.
"""
def get_hf_config(self) -> GlmAsrConfig:
return self.ctx.get_hf_config(GlmAsrConfig)
def get_hf_processor(self, **kwargs: object) -> GlmAsrProcessor:
return self.ctx.get_hf_processor(GlmAsrProcessor, **kwargs)
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
return self.get_hf_processor(**kwargs).feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return GlmAsrMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
class GlmAsrDummyInputsBuilder(BaseDummyInputsBuilder[GlmAsrProcessingInfo]):
"""
Builder for dummy inputs used in profiling and testing.
Generates dummy text prompts and audio data that match the expected
format for GLM-ASR model inputs. Used for memory profiling and
performance benchmarking.
"""
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
hf_processor = self.info.get_hf_processor()
return hf_processor.audio_token * num_audios
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
max_audio_len = getattr(
self.info.get_hf_processor(), "max_audio_len", DEFAULT_MAX_AUDIO_LEN_S
)
audio_len = int(max_audio_len * sampling_rate)
return {
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]): class GlmAsrMultiModalProcessor(BaseMultiModalProcessor["GlmAsrProcessingInfo"]):
""" """
GLM-ASR processor that inherits directly from BaseMultiModalProcessor GLM-ASR processor that inherits directly from BaseMultiModalProcessor
for better performance and cleaner implementation. for better performance and cleaner implementation.
""" """
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return GlmAsrMultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _calculate_chunk_counts( def _calculate_chunk_counts(
self, self,
audio_list: list[Any], audio_list: list[Any],
......
...@@ -109,6 +109,14 @@ class GraniteSpeechAudioInputs(TensorSchema): ...@@ -109,6 +109,14 @@ class GraniteSpeechAudioInputs(TensorSchema):
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
def get_data_parser(self):
feature_extractor = self.get_hf_processor().audio_processor
return MultiModalDataParser(
target_sr=feature_extractor.melspec_kwargs["sample_rate"],
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 1} return {"audio": 1}
...@@ -127,11 +135,6 @@ class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo): ...@@ -127,11 +135,6 @@ class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
class GraniteSpeechMultiModalProcessor( class GraniteSpeechMultiModalProcessor(
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo] BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]
): ):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().audio_processor
sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
return MultiModalDataParser(target_sr=sampling_rate)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
......
...@@ -599,6 +599,11 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo): ...@@ -599,6 +599,11 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
) -> HunYuanVLProcessor: ) -> HunYuanVLProcessor:
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
return HunYuanVLMultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
...@@ -710,9 +715,6 @@ class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo ...@@ -710,9 +715,6 @@ class HunYuanVLDummyInputsBuilder(BaseDummyInputsBuilder[HunYuanVLProcessingInfo
class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]): class HunYuanVLMultiModalProcessor(BaseMultiModalProcessor[HunYuanVLProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return HunYuanVLMultiModalDataParser()
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -349,10 +349,8 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo ...@@ -349,10 +349,8 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
tok_kwargs, tok_kwargs,
) )
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
image_sizes = [ image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images)) parsed_images.get_image_size(i) for i in range(len(parsed_images))
......
...@@ -984,6 +984,11 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -984,6 +984,11 @@ class KeyeProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
return KeyeMultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits( def get_supported_mm_limits(
self, self,
) -> Mapping[str, int | None]: ) -> Mapping[str, int | None]:
...@@ -1183,13 +1188,11 @@ class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -1183,13 +1188,11 @@ class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
return mm_data return mm_data
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): ... class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
pass
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return KeyeMultiModalDataParser()
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -274,16 +274,6 @@ class KeyeVL1_5Projector(nn.Module): ...@@ -274,16 +274,6 @@ class KeyeVL1_5Projector(nn.Module):
return hidden_states.view(*dims, -1) return hidden_states.view(*dims, -1)
class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
def get_max_frame_per_video(self) -> int:
return 2048
def get_supported_mm_limits(
self,
) -> Mapping[str, int | None]:
return {"image": None, "video": 1}
def _keye_field_config( def _keye_field_config(
hf_inputs: Mapping[str, torch.Tensor], hf_inputs: Mapping[str, torch.Tensor],
): ):
...@@ -365,10 +355,22 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): ...@@ -365,10 +355,22 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
return super()._parse_video_data(data) return super()._parse_video_data(data)
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
def _get_data_parser(self) -> MultiModalDataParser: def get_data_parser(self):
return KeyeVL1_5MultiModalDataParser() return KeyeVL1_5MultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_max_frame_per_video(self) -> int:
return 2048
def get_supported_mm_limits(
self,
) -> Mapping[str, int | None]:
return {"image": None, "video": 1}
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
......
...@@ -354,10 +354,8 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): ...@@ -354,10 +354,8 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
tok_kwargs, tok_kwargs,
) )
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
image_sizes = [ image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images)) parsed_images.get_image_size(i) for i in range(len(parsed_images))
......
...@@ -531,6 +531,14 @@ class MiDashengLMProcessingInfo(BaseProcessingInfo): ...@@ -531,6 +531,14 @@ class MiDashengLMProcessingInfo(BaseProcessingInfo):
feature_extractor = hf_processor.feature_extractor feature_extractor = hf_processor.feature_extractor
return feature_extractor return feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None} return {"audio": None}
...@@ -575,10 +583,6 @@ class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessing ...@@ -575,10 +583,6 @@ class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessing
class MiDashengLMMultiModalProcessor( class MiDashengLMMultiModalProcessor(
BaseMultiModalProcessor[MiDashengLMProcessingInfo] BaseMultiModalProcessor[MiDashengLMProcessingInfo]
): ):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
......
...@@ -53,7 +53,6 @@ from vllm.multimodal.parse import ( ...@@ -53,7 +53,6 @@ from vllm.multimodal.parse import (
ModalityData, ModalityData,
ModalityDataItems, ModalityDataItems,
MultiModalDataItems, MultiModalDataItems,
MultiModalDataParser,
) )
from vllm.multimodal.processing import ( from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
...@@ -174,6 +173,12 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser): ...@@ -174,6 +173,12 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
audio_pattern = "(<audio>./</audio>)" audio_pattern = "(<audio>./</audio>)"
def get_data_parser(self):
return MiniCPMOMultiModalDataParser(
target_sr=self.get_default_audio_sampling_rate(),
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {**super().get_supported_mm_limits(), "audio": None} return {**super().get_supported_mm_limits(), "audio": None}
...@@ -274,11 +279,6 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingIn ...@@ -274,11 +279,6 @@ class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingIn
class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]): class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMOMultiModalDataParser(
target_sr=self.info.get_default_audio_sampling_rate()
)
def get_audio_prompt_texts( def get_audio_prompt_texts(
self, self,
audio_lens: int, audio_lens: int,
...@@ -300,10 +300,8 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing ...@@ -300,10 +300,8 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
if (audios := mm_data.get("audios")) is None: if (audios := mm_data.get("audios")) is None:
return {} return {}
parsed_audios = ( parsed_audios = self.data_parser.parse_mm_data({"audio": audios}).get_items(
self._get_data_parser() "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
.parse_mm_data({"audio": audios})
.get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))
) )
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
......
...@@ -557,6 +557,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -557,6 +557,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object): def get_image_processor(self, **kwargs: object):
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
return MiniCPMVMultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_model_version(self): def get_model_version(self):
return get_version_by_config(self.get_hf_config()) return get_version_by_config(self.get_hf_config())
...@@ -736,9 +741,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]): ...@@ -736,9 +741,6 @@ class MiniCPMVDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _get_data_parser(self) -> MultiModalDataParser:
return MiniCPMVMultiModalDataParser()
def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str: def get_image_prompt_texts(self, image_size: ImageSize, image_idx: int = 0) -> str:
return self.info.get_slice_image_placeholder( return self.info.get_slice_image_placeholder(
image_size, image_size,
...@@ -765,10 +767,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -765,10 +767,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
if (images := mm_data.get("images")) is None: if (images := mm_data.get("images")) is None:
return {} return {}
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)
.parse_mm_data({"image": images})
.get_items("image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems))
) )
if isinstance(parsed_images, MiniCPMVImageEmbeddingItems): if isinstance(parsed_images, MiniCPMVImageEmbeddingItems):
...@@ -793,10 +793,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -793,10 +793,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
if (videos := mm_data.get("videos")) is None: if (videos := mm_data.get("videos")) is None:
return {} return {}
parsed_videos = ( parsed_videos = self.data_parser.parse_mm_data({"video": videos}).get_items(
self._get_data_parser() "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)
.parse_mm_data({"video": videos})
.get_items("video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems))
) )
if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems): if isinstance(parsed_videos, MiniCPMVVideoEmbeddingItems):
......
...@@ -620,10 +620,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]) ...@@ -620,10 +620,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
) )
images = mm_data["images"] images = mm_data["images"]
parsed_images = ( parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
self._get_data_parser() "image", ImageProcessorItems
.parse_mm_data({"image": images})
.get_items("image", ImageProcessorItems)
) )
tile_size = vision_config.image_size tile_size = vision_config.image_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment