Unverified Commit 79aa2446 authored by Wenlong Wang's avatar Wenlong Wang Committed by GitHub
Browse files

[Multi Modal] Configurable MM Profiling (#25631)


Signed-off-by: default avatarwwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2ed3f20d
...@@ -258,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in ...@@ -258,17 +258,21 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
``` ```
...@@ -438,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in ...@@ -438,16 +442,20 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
``` ```
......
...@@ -12,6 +12,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest ...@@ -12,6 +12,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.inputs import MultiModalInputs
...@@ -112,12 +114,26 @@ def _test_processing_correctness( ...@@ -112,12 +114,26 @@ def _test_processing_correctness(
processing_info = factories.info(ctx) processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits() supported_mm_limits = processing_info.get_supported_mm_limits()
limit_mm_per_prompt = { # Keep integer limits for local data generation
limit_mm_per_prompt_ints = {
modality: 3 if limit is None else limit modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items() for modality, limit in supported_mm_limits.items()
} }
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
if modality == "video":
return VideoDummyOptions(count=count)
if modality == "image":
return ImageDummyOptions(count=count)
if modality == "audio":
return AudioDummyOptions(count=count)
return BaseDummyOptions(count=count)
# Assign normalized DummyOptions to the model config
model_config.get_multimodal_config().limit_per_prompt = {
modality: _to_dummy_options(modality, count)
for modality, count in limit_mm_per_prompt_ints.items()
}
baseline_processor = factories.build_processor(ctx, cache=None) baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache) cached_processor = factories.build_processor(ctx, cache=cache)
...@@ -150,7 +166,7 @@ def _test_processing_correctness( ...@@ -150,7 +166,7 @@ def _test_processing_correctness(
k: k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]()) [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(limit + 1))] for _ in range(rng.randint(limit + 1))]
for k, limit in limit_mm_per_prompt.items() for k, limit in limit_mm_per_prompt_ints.items()
} }
mm_counts = {k: len(vs) for k, vs in mm_data.items()} mm_counts = {k: len(vs) for k, vs in mm_data.items()}
......
...@@ -17,23 +17,23 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -17,23 +17,23 @@ def test_profiling(model_id: str, max_model_len: int):
model_config_kwargs = { model_config_kwargs = {
"max_model_len": max_model_len, "max_model_len": max_model_len,
} }
mm_counts = {"image": 1}
ctx = build_model_context( ctx = build_model_context(
model_id, model_id,
model_config_kwargs=model_config_kwargs, model_config_kwargs=model_config_kwargs,
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt=mm_counts,
) )
mm_config = ctx.get_mm_config()
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
decoder_dummy_data = profiler.get_decoder_dummy_data( decoder_dummy_data = profiler.get_decoder_dummy_data(
max_model_len, max_model_len,
mm_counts=mm_config.limit_per_prompt, mm_counts=mm_counts,
) )
dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs( dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
max_model_len, max_model_len,
mm_counts=mm_config.limit_per_prompt, mm_counts=mm_counts,
) )
hf_config = ctx.get_hf_config(Llama4Config) hf_config = ctx.get_hf_config(Llama4Config)
...@@ -58,7 +58,7 @@ def test_profiling(model_id: str, max_model_len: int): ...@@ -58,7 +58,7 @@ def test_profiling(model_id: str, max_model_len: int):
profiled_tokens = profiler.get_mm_max_contiguous_tokens( profiled_tokens = profiler.get_mm_max_contiguous_tokens(
max_model_len, max_model_len,
mm_counts=mm_config.limit_per_prompt, mm_counts=mm_counts,
) )
assert total_tokens == profiled_tokens["image"] assert total_tokens == profiled_tokens["image"]
......
...@@ -15,6 +15,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest ...@@ -15,6 +15,8 @@ from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image from PIL import Image
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.distributed import (cleanup_dist_env_and_memory, from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment, init_distributed_environment,
initialize_model_parallel) initialize_model_parallel)
...@@ -236,7 +238,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str): ...@@ -236,7 +238,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
modality: 3 if limit is None else limit modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items() for modality, limit in supported_mm_limits.items()
} }
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
def _to_dummy_options(modality: str, count: int) -> BaseDummyOptions:
if modality == "video":
return VideoDummyOptions(count=count)
if modality == "image":
return ImageDummyOptions(count=count)
if modality == "audio":
return AudioDummyOptions(count=count)
return BaseDummyOptions(count=count)
model_config.get_multimodal_config().limit_per_prompt = {
modality: _to_dummy_options(modality, count)
for modality, count in limit_mm_per_prompt.items()
}
processor = factories.build_processor(ctx, cache=None) processor = factories.build_processor(ctx, cache=None)
with initialize_dummy_model(model_cls, model_config) as model: with initialize_dummy_model(model_cls, model_config) as model:
......
...@@ -276,7 +276,9 @@ class ModelConfig: ...@@ -276,7 +276,9 @@ class ModelConfig:
multimodal_config: Optional[MultiModalConfig] = None multimodal_config: Optional[MultiModalConfig] = None
"""Configuration for multimodal model. If `None`, this will be inferred """Configuration for multimodal model. If `None`, this will be inferred
from the architecture of `self.model`.""" from the architecture of `self.model`."""
limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None limit_mm_per_prompt: InitVar[Optional[dict[str, Union[int,
dict[str,
int]]]]] = None
media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None
mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None
mm_processor_cache_gb: InitVar[Optional[float]] = None mm_processor_cache_gb: InitVar[Optional[float]] = None
......
...@@ -4,15 +4,45 @@ ...@@ -4,15 +4,45 @@
import hashlib import hashlib
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import field from dataclasses import field
from typing import Any, Literal, Optional from typing import Any, Literal, Optional, Union
from pydantic import ConfigDict, Field, field_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config.utils import config from vllm.config.utils import config
@dataclass
class BaseDummyOptions:
"""Base options for generating dummy data during profiling."""
count: int = Field(999, ge=0)
@dataclass(config=ConfigDict(extra="forbid"))
class VideoDummyOptions(BaseDummyOptions):
"""Options for generating dummy video data during profiling."""
num_frames: Optional[int] = Field(None, gt=0)
width: Optional[int] = Field(None, gt=0)
height: Optional[int] = Field(None, gt=0)
@dataclass(config=ConfigDict(extra="forbid"))
class ImageDummyOptions(BaseDummyOptions):
"""Options for generating dummy image data during profiling."""
width: Optional[int] = Field(None, gt=0)
height: Optional[int] = Field(None, gt=0)
@dataclass(config=ConfigDict(extra="forbid"))
class AudioDummyOptions(BaseDummyOptions):
"""Options for generating dummy audio data during profiling."""
length: Optional[int] = Field(None, gt=0)
MMEncoderTPMode = Literal["weights", "data"] MMEncoderTPMode = Literal["weights", "data"]
MMCacheType = Literal["shm", "lru"] MMCacheType = Literal["shm", "lru"]
DummyOptions = Union[BaseDummyOptions, VideoDummyOptions, ImageDummyOptions,
AudioDummyOptions]
@config @config
...@@ -20,12 +50,22 @@ MMCacheType = Literal["shm", "lru"] ...@@ -20,12 +50,22 @@ MMCacheType = Literal["shm", "lru"]
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = field(default_factory=dict) limit_per_prompt: dict[str, DummyOptions] = field(default_factory=dict)
"""The maximum number of input items allowed per prompt for each modality. """The maximum number of input items and options allowed per
Defaults to 1 (V0) or 999 (V1) for each modality. prompt for each modality.
Defaults to 999 for each modality.
Legacy format (count only):
{"image": 16, "video": 2}
Configurable format (with options):
{"video": {"count": 1, "num_frames": 32, "width": 512, "height": 512},
"image": {"count": 5, "width": 512, "height": 512}}
For example, to allow up to 16 images and 2 videos per prompt: Mixed format (combining both):
`{"image": 16, "video": 2}`""" {"image": 16, "video": {"count": 1, "num_frames": 32, "width": 512,
"height": 512}}
"""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities. """Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set For example, to set num_frames for video, set
...@@ -84,6 +124,27 @@ class MultiModalConfig: ...@@ -84,6 +124,27 @@ class MultiModalConfig:
from each video to be pruned. from each video to be pruned.
""" """
@field_validator("limit_per_prompt", mode="before")
@classmethod
def _validate_limit_per_prompt(
cls, value: dict[str, Union[int,
dict[str,
int]]]) -> dict[str, DummyOptions]:
for k, v in value.items():
# Handle legacy format where only count is specified
if isinstance(v, int):
v = {"count": v}
# Convert to the appropriate DummyOptions subclass
if k == "video":
value[k] = VideoDummyOptions(**v)
elif k == "image":
value[k] = ImageDummyOptions(**v)
elif k == "audio":
value[k] = AudioDummyOptions(**v)
else:
value[k] = BaseDummyOptions(**v)
return value
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -106,12 +167,22 @@ class MultiModalConfig: ...@@ -106,12 +167,22 @@ class MultiModalConfig:
def get_limit_per_prompt(self, modality: str) -> int: def get_limit_per_prompt(self, modality: str) -> int:
""" """
Get the maximum number of input items allowed per prompt Get the maximum number of input items allowed per prompt
for the given modality. for the given modality (backward compatible).
"""
limit_data = self.limit_per_prompt.get(modality)
if limit_data is None:
# Unspecified modality is set to 999 by default
return 999
return limit_data.count
def get_dummy_options(self, modality: str) -> Optional[BaseDummyOptions]:
"""
Get the configurable dummy data options for a modality.
Returns None if no options are configured for this modality.
""" """
return self.limit_per_prompt.get( # All values are now DummyOptions after normalization
modality, return self.limit_per_prompt.get(modality)
999 if envs.VLLM_USE_V1 else 1,
)
def merge_mm_processor_kwargs( def merge_mm_processor_kwargs(
self, self,
......
...@@ -376,7 +376,7 @@ class EngineArgs: ...@@ -376,7 +376,7 @@ class EngineArgs:
quantization: Optional[QuantizationMethods] = ModelConfig.quantization quantization: Optional[QuantizationMethods] = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager enforce_eager: bool = ModelConfig.enforce_eager
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \ limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \
get_field(MultiModalConfig, "limit_per_prompt") get_field(MultiModalConfig, "limit_per_prompt")
interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings
media_io_kwargs: dict[str, dict[str, media_io_kwargs: dict[str, dict[str,
......
...@@ -10,6 +10,7 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention ...@@ -10,6 +10,7 @@ from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor from transformers.models.aria.processing_aria import AriaProcessor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
...@@ -431,17 +432,21 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]): ...@@ -431,17 +432,21 @@ class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
vision_config = self.info.get_vision_config() vision_config = self.info.get_vision_config()
max_image_size = vision_config.image_size max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=max_image_size, self._get_dummy_images(width=max_image_size,
height=max_image_size, height=max_image_size,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -16,6 +16,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import ( ...@@ -16,6 +16,7 @@ from transformers.models.got_ocr2.image_processing_got_ocr2 import (
get_optimal_tiled_canvas) get_optimal_tiled_canvas)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
...@@ -166,16 +167,20 @@ class AyaVisionDummyInputsBuilder( ...@@ -166,16 +167,20 @@ class AyaVisionDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_size = \ image_size = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=image_size.width, self._get_dummy_images(width=image_size.width,
height=image_size.height, height=image_size.height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -10,6 +10,7 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig, ...@@ -10,6 +10,7 @@ from transformers import (BatchFeature, Blip2Config, Blip2QFormerConfig,
apply_chunking_to_forward) apply_chunking_to_forward)
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -435,6 +436,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): ...@@ -435,6 +436,7 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -442,11 +444,14 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]): ...@@ -442,11 +444,14 @@ class Blip2DummyInputsBuilder(BaseDummyInputsBuilder[Blip2ProcessingInfo]):
max_image_size = vision_config.image_size max_image_size = vision_config.image_size
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=max_image_size, self._get_dummy_images(width=max_image_size,
height=max_image_size, height=max_image_size,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -14,6 +14,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor, ...@@ -14,6 +14,7 @@ from transformers import (BatchFeature, ChameleonConfig, ChameleonProcessor,
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -92,17 +93,21 @@ class ChameleonDummyInputsBuilder( ...@@ -92,17 +93,21 @@ class ChameleonDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
config = self.info.get_hf_config() config = self.info.get_hf_config()
width = height = config.vq_config.resolution width = height = config.vq_config.resolution
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=width, self._get_dummy_images(width=width,
height=height, height=height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -16,6 +16,7 @@ from transformers.models.cohere2_vision.processing_cohere2_vision import ( ...@@ -16,6 +16,7 @@ from transformers.models.cohere2_vision.processing_cohere2_vision import (
Cohere2VisionProcessor) Cohere2VisionProcessor)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import MulAndSilu from vllm.model_executor.layers.activation import MulAndSilu
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -209,16 +210,20 @@ class Cohere2VisionDummyInputsBuilder( ...@@ -209,16 +210,20 @@ class Cohere2VisionDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_size = \ image_size = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=image_size.width, self._get_dummy_images(width=image_size.width,
height=image_size.height, height=image_size.height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -14,6 +14,7 @@ from einops import rearrange, repeat ...@@ -14,6 +14,7 @@ from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
...@@ -191,16 +192,20 @@ class DeepseekVL2DummyInputsBuilder( ...@@ -191,16 +192,20 @@ class DeepseekVL2DummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
max_image_size = self.info.get_image_size_with_most_features() max_image_size = self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=max_image_size.width, self._get_dummy_images(width=max_image_size.width,
height=max_image_size.height, height=max_image_size.height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -13,6 +13,7 @@ from vllm.attention.backends.registry import _Backend ...@@ -13,6 +13,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (check_upstream_fa_availability, from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend) maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
...@@ -91,17 +92,21 @@ class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder): ...@@ -91,17 +92,21 @@ class DotsOCRDummyInputsBuilder(Qwen2VLDummyInputsBuilder):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501 target_width, target_height = self.info.get_image_size_with_most_features( # noqa: E501
) )
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
} }
......
...@@ -38,6 +38,7 @@ from vllm.attention.backends.registry import _Backend ...@@ -38,6 +38,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (check_upstream_fa_availability, from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend) maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1184,6 +1185,7 @@ class Ernie4_5_VLDummyInputsBuilder( ...@@ -1184,6 +1185,7 @@ class Ernie4_5_VLDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -1193,16 +1195,21 @@ class Ernie4_5_VLDummyInputsBuilder( ...@@ -1193,16 +1195,21 @@ class Ernie4_5_VLDummyInputsBuilder(
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"video": "video":
self._get_dummy_videos(width=target_width, self._get_dummy_videos(width=target_width,
height=target_height, height=target_height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos) num_videos=num_videos,
overrides=video_overrides)
} }
......
...@@ -27,6 +27,7 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor, ...@@ -27,6 +27,7 @@ from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
FuyuProcessor) FuyuProcessor)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -136,16 +137,20 @@ class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): ...@@ -136,16 +137,20 @@ class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -11,6 +11,7 @@ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs ...@@ -11,6 +11,7 @@ from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -241,17 +242,21 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]): ...@@ -241,17 +242,21 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -16,6 +16,7 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig, ...@@ -16,6 +16,7 @@ from transformers.models.gemma3n import (Gemma3nAudioConfig,
from transformers.models.siglip import SiglipImageProcessorFast from transformers.models.siglip import SiglipImageProcessorFast
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -153,6 +154,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): ...@@ -153,6 +154,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
...@@ -163,13 +165,19 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]): ...@@ -163,13 +165,19 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
img_width = image_processor.size.get("width", 224) img_width = image_processor.size.get("width", 224)
img_height = image_processor.size.get("height", 224) img_height = image_processor.size.get("height", 224)
image_overrides = mm_options.get("image") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=img_width, self._get_dummy_images(width=img_width,
height=img_height, height=img_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
} }
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union from typing import Annotated, Any, Callable, Literal, Optional, Union, override
import numpy as np import numpy as np
import torch import torch
...@@ -50,6 +50,7 @@ from vllm.attention.backends.registry import _Backend ...@@ -50,6 +50,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (check_upstream_fa_availability, from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend) maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
from vllm.distributed import (get_tensor_model_parallel_world_size, from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state) parallel_state)
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -1110,6 +1111,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ...@@ -1110,6 +1111,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -1118,17 +1120,23 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ...@@ -1118,17 +1120,23 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
self.info.get_image_size_with_most_features()) self.info.get_image_size_with_most_features())
target_num_frames = self.info.get_num_frames_with_most_features( target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts) seq_len, mm_counts)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
height=target_height, height=target_height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
overrides=video_overrides,
), ),
} }
...@@ -1139,7 +1147,31 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ...@@ -1139,7 +1147,31 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
height: int, height: int,
num_frames: int, num_frames: int,
num_videos: int, num_videos: int,
overrides: Optional[VideoDummyOptions] = None,
) -> list[VideoItem]: ) -> list[VideoItem]:
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames, num_frames)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, override.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = [] video_items = []
for i in range(num_videos): for i in range(num_videos):
......
...@@ -19,6 +19,7 @@ from transformers.tokenization_utils_base import TextInput ...@@ -19,6 +19,7 @@ from transformers.tokenization_utils_base import TextInput
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -465,6 +466,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): ...@@ -465,6 +466,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -472,11 +474,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): ...@@ -472,11 +474,14 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
target_width = target_height = vision_config["image_size"] target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment