Unverified commit 987506bc, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Simplify dummy data generation (#35025)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c645e9a2
......@@ -293,21 +293,22 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image":
self._get_dummy_images(width=target_width,
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
overrides=image_overrides)
overrides=image_overrides,
)
}
```
......@@ -479,17 +480,16 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image":
self._get_dummy_images(
"image": self._get_dummy_images(
width=target_width,
height=target_height,
num_images=num_images,
......
......@@ -116,7 +116,7 @@ def test_dummy_data_generation(mock_ctx):
builder = AudioFlamingo3DummyInputsBuilder(info)
mm_counts = {"audio": 2}
dummy_data = builder.get_dummy_mm_data(100, mm_counts, None)
dummy_data = builder.get_dummy_mm_data(100, mm_counts, {})
assert "audio" in dummy_data
assert len(dummy_data["audio"]) == 2
......
......@@ -195,6 +195,7 @@ def get_text_token_prompts(
inputs = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
mm_options={},
)
text_prompt = None
token_prompt = (
......@@ -224,6 +225,7 @@ def get_text_token_prompts(
inputs = dummy_inputs.get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
mm_options={},
)
assert isinstance(inputs.prompt, str)
......
......@@ -97,6 +97,7 @@ def create_batched_mm_kwargs(
processor_inputs = dummy_inputs.get_dummy_processor_inputs(
seq_len=model_config.max_model_len,
mm_counts=mm_counts,
mm_options={},
)
mm_items = processor_inputs.mm_items
resized_mm_data = {
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from typing import Any, Literal, TypeAlias
from typing import Any, Literal, TypeAlias, TypedDict, final
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
......@@ -43,11 +43,29 @@ class AudioDummyOptions(BaseDummyOptions):
length: int | None = Field(None, gt=0)
@final
class MultiModalDummyOptionsBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ImageDummyOptions
"""Options for dummy images."""
video: VideoDummyOptions
"""Options for dummy videos."""
audio: AudioDummyOptions
"""Options for dummy audios."""
MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"]
DummyOptions: TypeAlias = (
BaseDummyOptions | VideoDummyOptions | ImageDummyOptions | AudioDummyOptions
)
MMDummyOptions: TypeAlias = dict[str, BaseDummyOptions]
"""
A dictionary containing an entry for each modality type of dummy data.
The built-in modalities are defined by
[`MultiModalDummyOptionsBuiltins`][vllm.config.multimodal.MultiModalDummyOptionsBuiltins].
"""
@config
......@@ -57,7 +75,7 @@ class MultiModalConfig:
language_model_only: bool = False
"""If True, disables all multimodal inputs by setting all modality limits to 0.
Equivalent to setting `--limit-mm-per-prompt` to 0 for every modality."""
limit_per_prompt: dict[str, DummyOptions] = Field(default_factory=dict)
limit_per_prompt: MMDummyOptions = Field(default_factory=dict)
"""The maximum number of input items and options allowed per
prompt for each modality.
......@@ -158,22 +176,27 @@ class MultiModalConfig:
@field_validator("limit_per_prompt", mode="before")
@classmethod
def _validate_limit_per_prompt(
cls, value: dict[str, int | dict[str, int]]
) -> dict[str, DummyOptions]:
cls,
value: dict[str, int | dict[str, int]],
) -> MMDummyOptions:
out: MMDummyOptions = {}
for k, v in value.items():
# Handle legacy format where only count is specified
if isinstance(v, int):
v = {"count": v}
# Convert to the appropriate DummyOptions subclass
if k == "video":
value[k] = VideoDummyOptions(**v)
out[k] = VideoDummyOptions(**v)
elif k == "image":
value[k] = ImageDummyOptions(**v)
out[k] = ImageDummyOptions(**v)
elif k == "audio":
value[k] = AudioDummyOptions(**v)
out[k] = AudioDummyOptions(**v)
else:
value[k] = BaseDummyOptions(**v)
return value
out[k] = BaseDummyOptions(**v)
return out
@field_validator("mm_encoder_attn_backend", mode="before")
@classmethod
......@@ -240,15 +263,8 @@ class MultiModalConfig:
if limit_data is None:
# Unspecified modality is set to 999 by default
return 999
return limit_data.count
def get_dummy_options(self, modality: str) -> BaseDummyOptions | None:
"""
Get the configurable dummy data options for a modality.
Returns None if no options are configured for this modality.
"""
# All values are now DummyOptions after normalization
return self.limit_per_prompt.get(modality)
return limit_data.count
def merge_mm_processor_kwargs(
self,
......
......@@ -444,15 +444,14 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
vision_config = self.info.get_vision_config()
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -252,16 +252,13 @@ class AudioFlamingo3DummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor(
**(mm_processor_kwargs or {})
)
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = MAX_AUDIO_LEN * sampling_rate
num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
audio_overrides = mm_options.get("audio")
return {
"audio": self._get_dummy_audios(
......
......@@ -191,13 +191,12 @@ class AyaVisionDummyInputsBuilder(BaseDummyInputsBuilder[AyaVisionProcessingInfo
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
image_size = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -249,8 +249,7 @@ class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
hf_config = self.info.get_hf_config()
......@@ -258,7 +257,7 @@ class BagelDummyInputsBuilder(BaseDummyInputsBuilder[BagelProcessingInfo]):
# Use the configured image size
image_size = vit_config.image_size
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -90,14 +90,13 @@ class BeeDummyInputsBuilder(LlavaDummyInputsBuilder[BeeProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -445,8 +445,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
......@@ -454,7 +453,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -116,15 +116,14 @@ class ChameleonDummyInputsBuilder(BaseDummyInputsBuilder[ChameleonProcessingInfo
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
config = self.info.get_hf_config()
width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -174,14 +174,13 @@ class CLIPDummyInputsBuilder(BaseDummyInputsBuilder[CLIPProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -197,13 +197,12 @@ class Cohere2VisionDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
image_size = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -132,12 +132,12 @@ class ColModernVBertDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
width=target_width,
......
......@@ -255,8 +255,7 @@ class DeepseekOCRDummyInputsBuilder(BaseDummyInputsBuilder[DeepseekOCRProcessing
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
......
......@@ -137,8 +137,7 @@ class DeepseekOCR2DummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
......
......@@ -214,14 +214,13 @@ class DeepseekVL2DummyInputsBuilder(BaseDummyInputsBuilder[DeepseekVL2Processing
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
max_image_size = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -106,17 +106,13 @@ class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
mm_processor_kwargs = mm_processor_kwargs or {}
target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501
mm_processor_kwargs.get("max_pixels", None)
)
target_width, target_height = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
image_overrides = mm_options.get("image")
return {
"image": self._get_dummy_images(
......
......@@ -1168,8 +1168,7 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Mapping[str, BaseDummyOptions] | None = None,
mm_processor_kwargs: Mapping[str, object] | None = None,
mm_options: Mapping[str, BaseDummyOptions],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
......@@ -1179,8 +1178,8 @@ class Ernie4_5_VLDummyInputsBuilder(BaseDummyInputsBuilder[Ernie4_5_VLProcessing
seq_len, mm_counts
)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
image_overrides = mm_options.get("image")
video_overrides = mm_options.get("video")
return {
"image": self._get_dummy_images(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment