Unverified Commit 51550179 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Define MM data parser in processing info instead of processor itself (#33260)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 07ea184f
......@@ -1860,6 +1860,12 @@ def get_frame_times_and_chosen_fps(
class Molmo2ProcessingInfo(BaseProcessingInfo):
def get_data_parser(self):
return MultiModalDataParser(
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper:
processor = self.ctx.get_hf_processor(**kwargs)
hf_config = self.ctx.get_hf_config()
......@@ -2183,9 +2189,6 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
return prompt_tokens
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(video_needs_metadata=True)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -1143,6 +1143,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
def supports_video(self):
return self.get_hf_processor().supports_video
def get_data_parser(self):
return MultiModalDataParser(
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self):
video_limit = {"video": None} if self.supports_video else {}
return {**super().get_supported_mm_limits(), **video_limit}
......@@ -1274,9 +1280,6 @@ class NanoNemotronVLMultiModalProcessor(
):
"""MultiModalProcessor extended for video support"""
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(video_needs_metadata=True)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
......
......@@ -25,7 +25,7 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
PromptReplacement,
......@@ -53,6 +53,12 @@ from .utils import (
class OpenCUAProcessingInfo(Qwen2VLProcessingInfo):
def get_data_parser(self):
return Qwen2VLMultiModalDataParser(
self.get_hf_config().vision_config.spatial_merge_size,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_hf_config(self):
return self.ctx.get_hf_config()
......@@ -125,11 +131,6 @@ class OpenCUAProcessor(Qwen2VLProcessor):
class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser(
self.info.get_hf_config().vision_config.spatial_merge_size
)
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
......
......@@ -568,6 +568,15 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_feature_extractor(self, **kwargs: object) -> SequenceFeatureExtractor:
return self.get_hf_processor(**kwargs).audio_processor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
audio_resample_method="scipy",
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None}
......@@ -844,12 +853,6 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate, audio_resample_method="scipy"
)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -77,7 +77,6 @@ from vllm.multimodal.parse import (
DictEmbeddingItems,
ModalityDataItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
......@@ -227,6 +226,16 @@ class Qwen2_5OmniThinkerProcessingInfo(
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return Qwen2_5OmniThinkerMultiModalDataParser(
spatial_merge_size=self.get_hf_config().vision_config.spatial_merge_size,
target_sr=feature_extractor.sampling_rate,
target_channels=self.get_target_channels(),
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_target_channels(self) -> int:
"""Return target audio channels for Qwen2.5 Omni models (mono)."""
return 1
......@@ -310,14 +319,6 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
class Qwen2_5OmniThinkerMultiModalProcessor(
BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return Qwen2_5OmniThinkerMultiModalDataParser(
spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size,
target_sr=feature_extractor.sampling_rate,
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -127,6 +127,30 @@ def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
return feat_lengths, output_lengths
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
return dict(
audio_embeds=MultiModalFieldConfig.batched("audio"),
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_qwen2audio_field_config,
)
return super()._parse_audio_data(data)
class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Qwen2AudioConfig)
......@@ -140,6 +164,15 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return Qwen2AudioMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.get_target_channels(),
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_target_channels(self) -> int:
"""Return target audio channels for Qwen2 Audio models (mono)."""
return 1
......@@ -178,38 +211,7 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
}
def _qwen2audio_field_config(hf_inputs: Mapping[str, torch.Tensor]):
return dict(
audio_embeds=MultiModalFieldConfig.batched("audio"),
input_features=MultiModalFieldConfig.batched("audio"),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
def _parse_audio_data(
self,
data: dict[str, torch.Tensor] | ModalityData[AudioItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="audio",
required_fields={"audio_embeds"},
fields_factory=_qwen2audio_field_config,
)
return super()._parse_audio_data(data)
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return Qwen2AudioMultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -806,6 +806,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs: object) -> Qwen2VLImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
return Qwen2VLMultiModalDataParser(
self.get_hf_config().vision_config.spatial_merge_size,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "video": None}
......@@ -1039,11 +1045,6 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser(
self.info.get_hf_config().vision_config.spatial_merge_size
)
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
......
......@@ -81,7 +81,7 @@ from vllm.multimodal.inputs import (
PlaceholderRange,
VideoItem,
)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
......@@ -624,6 +624,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
def get_video_processor(self, **kwargs: object) -> Qwen3VLVideoProcessor:
return self.get_hf_processor(**kwargs).video_processor
def get_data_parser(self):
return Qwen2VLMultiModalDataParser(
self.get_hf_config().vision_config.spatial_merge_size,
video_needs_metadata=True,
expected_hidden_size=self._get_expected_hidden_size(),
)
def _get_vision_info(
self,
*,
......@@ -901,12 +908,6 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
class Qwen3VLMultiModalProcessor(BaseMultiModalProcessor[Qwen3VLProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser(
self.info.get_hf_config().vision_config.spatial_merge_size,
video_needs_metadata=True,
)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -19,6 +19,7 @@
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any
import torch
......@@ -38,7 +39,6 @@ from vllm.model_executor.layers.pooler import IdentityPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import (
ImageItem,
ModalityData,
......@@ -89,7 +89,45 @@ def _terratorch_field_factory(input_definition: InputDefinition):
return _terratorch_field_config
class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, input_definition: InputDefinition, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_definition = input_definition
def _parse_image_data(
self,
data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields=_terratorch_field_names(self.input_definition),
fields_factory=_terratorch_field_factory(self.input_definition),
)
return super()._parse_image_data(data)
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
if "image" not in mm_data:
mm_data = {"image": mm_data}
return super().parse_mm_data(mm_data)
class TerratorchProcessingInfo(BaseProcessingInfo):
@cached_property
def input_definition(self) -> InputDefinition:
pretrained_cfg = self.get_hf_config().to_dict()["pretrained_cfg"]
return InputDefinition(**pretrained_cfg["input"])
def get_data_parser(self):
return TerratorchMultiModalDataParser(
self.input_definition,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
......@@ -123,55 +161,13 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
return self.dummy_data_generator.get_dummy_mm_data()
class TerratorchMultiModalDataParser(MultiModalDataParser):
def __init__(self, input_definition: InputDefinition, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_definition = input_definition
def _parse_image_data(
self,
data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any] | None:
if isinstance(data, dict):
return DictEmbeddingItems(
data,
modality="image",
required_fields=_terratorch_field_names(self.input_definition),
fields_factory=_terratorch_field_factory(self.input_definition),
)
return super()._parse_image_data(data)
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
if "image" not in mm_data:
mm_data = {"image": mm_data}
return super().parse_mm_data(mm_data)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
def __init__(
self,
info: TerratorchProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
*,
cache: MultiModalProcessorOnlyCache | None = None,
) -> None:
pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
self._input_definition = InputDefinition(**pretrained_cfg["input"])
super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
def _get_data_parser(self) -> MultiModalDataParser:
return TerratorchMultiModalDataParser(self._input_definition)
class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessingInfo]):
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _terratorch_field_factory(self._input_definition)(hf_inputs)
return _terratorch_field_factory(self.info.input_definition)(hf_inputs)
def _get_prompt_updates(
self,
......
......@@ -133,6 +133,15 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.get_target_channels(),
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_target_channels(self) -> int:
"""Return target audio channels for Ultravox models (mono)."""
return 1
......@@ -171,13 +180,6 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo])
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.info.get_target_channels(),
)
def _call_hf_processor(
self,
prompt: str,
......
......@@ -203,6 +203,12 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self) -> VoxtralProcessorAdapter:
return VoxtralProcessorAdapter(self.get_tokenizer())
def get_data_parser(self):
return MultiModalDataParser(
target_sr=self.get_hf_processor().sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": 5} # Performance tends to degrade after 5
......@@ -335,10 +341,6 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_info, True
def _get_data_parser(self) -> MultiModalDataParser:
sampling_rate = self.info.get_hf_processor().sampling_rate
return MultiModalDataParser(target_sr=sampling_rate)
@MULTIMODAL_REGISTRY.register_processor(
VoxtralMultiModalProcessor,
......
......@@ -644,6 +644,15 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.get_target_channels(),
expected_hidden_size=self._get_expected_hidden_size(),
)
@property
def skip_prompt_length_check(self) -> bool:
return True # Because the encoder prompt is padded
......@@ -693,13 +702,6 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
target_channels=self.info.get_target_channels(),
)
def create_encoder_prompt(
self,
prompt: str | list[int],
......
......@@ -17,6 +17,7 @@ import torch
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.multimodal.parse import MultiModalDataParser
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
......@@ -569,6 +570,35 @@ class BaseProcessingInfo:
"""
return self.ctx.get_hf_processor(**kwargs)
def _get_expected_hidden_size(self) -> int | None:
"""
Get expected hidden size for embedding validation if `mm_embeds` are enabled.
This validates hidden dimensions to prevent a vulnerability where embeddings
with correct `ndim` but wrong `shape` could cause crashes at inference time.
"""
model_config = self.ctx.model_config
mm_config = model_config.get_multimodal_config()
if mm_config.enable_mm_embeds:
return model_config.get_inputs_embeds_size()
return None
def get_data_parser(self) -> MultiModalDataParser:
"""
Constructs a parser to preprocess multi-modal data items
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
"""
return MultiModalDataParser(
expected_hidden_size=self._get_expected_hidden_size(),
)
@property
def skip_prompt_length_check(self) -> bool:
return False
......
......@@ -40,7 +40,6 @@ from ..parse import (
DictEmbeddingItems,
EmbeddingItems,
MultiModalDataItems,
MultiModalDataParser,
)
from .context import (
BaseProcessingInfo,
......@@ -990,7 +989,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self.dummy_inputs = dummy_inputs
self.cache = cache
self.data_parser = self._get_data_parser()
if hasattr(self, "_get_data_parser"):
logger.warning_once(
"BaseMultiModalProcessor._get_data_parser is deprecated "
"and will be removed in v0.16."
"You should override `info.build_data_parser` instead."
)
self.data_parser = self._get_data_parser() # type: ignore
else:
self.data_parser = self.info.get_data_parser()
# Avoid unnecessary recomputation
self._supported_mm_limits = self.info.get_supported_mm_limits()
......@@ -1014,26 +1022,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) -> MultiModalInputs:
return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
def _get_data_parser(self) -> MultiModalDataParser:
"""
Construct a parser to preprocess multi-modal data items
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
You can support additional modalities by creating a subclass
of [`MultiModalDataParser`][vllm.multimodal.parse.MultiModalDataParser]
that has additional subparsers.
"""
# Get expected hidden size for embedding validation if mm_embeds enabled
# This validates hidden dimensions to prevent vulnerabilities: embeddings
# with correct ndim but wrong shape could cause crashes at inference time
mm_config = self.info.ctx.model_config.get_multimodal_config()
expected_hidden_size = None
if mm_config.enable_mm_embeds:
expected_hidden_size = self.info.ctx.model_config.get_inputs_embeds_size()
return MultiModalDataParser(expected_hidden_size=expected_hidden_size)
def validate_num_items(
self,
modality: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment