Unverified Commit 88c3e114 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Move MM data parsing outside processor (#33408)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 92924b2d
......@@ -227,9 +227,8 @@ class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingIn
# HF processor pops the `num_patches` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
......
......@@ -201,20 +201,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if prompt and mm_data:
if prompt and mm_items:
raise ValueError(
"CLIP accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_data:
if mm_items:
# For multi-modal data, the prompt after processing should
# only contain the dummy image tokens
tokenization_kwargs = {
......@@ -224,7 +224,7 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]):
return super().apply(
prompt=prompt,
mm_data=mm_data,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
......
......@@ -262,9 +262,8 @@ class Cohere2VisionMultiModalProcessor(
hf_processor = self.info.get_hf_processor(**mm_kwargs)
# Fallback calculation if HF processor didn't provide num_patches
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
num_patches = [
self.info.get_num_patches(
......
......@@ -290,9 +290,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
......
......@@ -349,9 +349,8 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
tok_kwargs,
)
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
......
......@@ -357,9 +357,8 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
tok_kwargs,
)
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
......
......@@ -769,7 +769,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -785,13 +785,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
result = super().apply(
prompt,
mm_data,
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_items = self._to_mm_items(mm_data)
mm_item_counts = mm_items.get_all_counts()
mm_kwargs = result["mm_kwargs"]
mm_hashes = result["mm_hashes"]
......
......@@ -300,7 +300,8 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
if (audios := mm_data.get("audios")) is None:
return {}
parsed_audios = self.data_parser.parse_mm_data({"audio": audios}).get_items(
mm_items = self.info.parse_mm_data({"audio": audios}, validate=False)
parsed_audios = mm_items.get_items(
"audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
)
......
......@@ -767,7 +767,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
if (images := mm_data.get("images")) is None:
return {}
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items(
"image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems)
)
......@@ -793,7 +794,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
if (videos := mm_data.get("videos")) is None:
return {}
parsed_videos = self.data_parser.parse_mm_data({"video": videos}).get_items(
mm_items = self.info.parse_mm_data({"video": videos}, validate=False)
parsed_videos = mm_items.get_items(
"video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems)
)
......
......@@ -609,9 +609,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo])
)
images = mm_data["images"]
parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items(
"image", ImageProcessorItems
)
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
tile_size = vision_config.image_size
possible_resolutions = find_supported_resolutions(
......
......@@ -660,7 +660,7 @@ class NemotronParseMultiModalProcessor(
def create_encoder_prompt(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
) -> str | list[int]:
return [0]
......
......@@ -225,14 +225,14 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
mm_inputs = super().apply(
prompt,
mm_data,
mm_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
......
......@@ -303,9 +303,11 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
return ProcessorInputs(
prompt=dummy_tokens,
mm_data=dummy_mm_data,
mm_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
......
......@@ -187,20 +187,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
*,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
if prompt and mm_data:
if prompt and mm_items:
raise ValueError(
"Siglip accepts text-only or image-only inputs, not both! "
"Image-only inputs means passing an image with an empty text "
"prompt."
)
if mm_data:
if mm_items:
# For multi-modal data, the prompt after processing should
# only contain the image token
tokenization_kwargs = {
......@@ -210,7 +210,7 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]):
return super().apply(
prompt=prompt,
mm_data=mm_data,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
......
......@@ -180,20 +180,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {}
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_hashes = self._hash_mm_items(
mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids
)
mm_processed_data = BatchFeature(
mm_data.get("image", mm_data), tensor_type="pt"
)
_, passthrough_data = self._get_hf_mm_data(mm_items)
mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt")
mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
......
......@@ -174,7 +174,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def apply(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object] | None = None,
mm_uuids: MultiModalUUIDDict | None = None,
......@@ -188,7 +188,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if not isinstance(prompt, str):
# the prompt is the tokenized ids which is not supported
......
......@@ -262,11 +262,14 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
)
res = tokenizer.mistral.encode_chat_completion(request)
dummy_tokens = res.tokens
# whixtral tokenizer adds padding to the audio
# so we need to update the audio arrays
dummy_mm_data["audio"] = [a.audio_array for a in res.audios]
return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)
dummy_mm_inputs = self.info.parse_mm_data(
# whixtral tokenizer adds padding to the audio
# so we need to update the audio arrays
{**dummy_mm_data, "audio": [a.audio_array for a in res.audios]},
)
return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs)
class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
......
......@@ -705,7 +705,7 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
def create_encoder_prompt(
self,
prompt: str | list[int],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
) -> str | list[int]:
# Strictly speaking, whisper encoder only accept audio features.
# We create a dummy encoder prompt here which will be padded to
......
......@@ -14,7 +14,13 @@ import torch
from typing_extensions import TypeVar
from vllm.logger import init_logger
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.multimodal.parse import (
DictEmbeddingItems,
EmbeddingItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
......@@ -596,6 +602,10 @@ class BaseProcessingInfo:
expected_hidden_size=self._get_expected_hidden_size(),
)
@cached_property
def data_parser(self) -> MultiModalDataParser:
return self.get_data_parser()
@property
def skip_prompt_length_check(self) -> bool:
return False
......@@ -655,6 +665,36 @@ class BaseProcessingInfo:
raise ValueError(msg)
def parse_mm_data(
self,
mm_data: MultiModalDataDict,
*,
validate: bool = True,
) -> MultiModalDataItems:
"""
Normalize
[`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict]
to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems]
before passing them to
[`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data].
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
if validate:
mm_config = self.ctx.model_config.get_multimodal_config()
if not mm_config.enable_mm_embeds:
for modality, items in mm_items.items():
if isinstance(items, (EmbeddingItems, DictEmbeddingItems)):
raise ValueError(
f"You must set `--enable-mm-embeds` to input "
f"`{modality}_embeds`"
)
for modality, items in mm_items.items():
self.validate_num_items(modality, len(items))
return mm_items
def get_mm_max_tokens_per_item(
self,
seq_len: int,
......
......@@ -18,6 +18,7 @@ from vllm.config.multimodal import (
from vllm.logger import init_logger
from ..inputs import MultiModalDataDict
from ..parse import MultiModalDataItems
from .context import BaseProcessingInfo
_I = TypeVar("_I", bound=BaseProcessingInfo)
......@@ -33,7 +34,7 @@ class ProcessorInputs:
"""
prompt: str | list[int]
mm_data: MultiModalDataDict
mm_items: MultiModalDataItems
hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict)
tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
......@@ -93,15 +94,14 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
mm_options: Configurable options per modality (optional)
"""
dummy_text = self.get_dummy_text(mm_counts)
# Use the unified function for both legacy and configurable cases
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_mm_items = self.info.parse_mm_data(dummy_mm_data)
tokenization_kwargs = {"truncation": False}
return ProcessorInputs(
prompt=dummy_text,
mm_data=dummy_mm_data,
mm_items=dummy_mm_items,
tokenization_kwargs=tokenization_kwargs,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment