Unverified commit 79aa2446, authored by Wenlong Wang, committed by GitHub
Browse files

[Multi Modal] Configurable MM Profiling (#25631)


Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 2ed3f20d
...@@ -8,6 +8,7 @@ from torch import nn ...@@ -8,6 +8,7 @@ from torch import nn
from transformers import BatchFeature, PaliGemmaConfig from transformers import BatchFeature, PaliGemmaConfig
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...@@ -106,6 +107,7 @@ class PaliGemmaDummyInputsBuilder( ...@@ -106,6 +107,7 @@ class PaliGemmaDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -113,11 +115,14 @@ class PaliGemmaDummyInputsBuilder( ...@@ -113,11 +115,14 @@ class PaliGemmaDummyInputsBuilder(
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=max_image_size, self._get_dummy_images(width=max_image_size,
height=max_image_size, height=max_image_size,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -25,6 +25,7 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig, ...@@ -25,6 +25,7 @@ from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
ProcessorMixin) ProcessorMixin)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -356,17 +357,21 @@ class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]): ...@@ -356,17 +357,21 @@ class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -17,6 +17,7 @@ from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( ...@@ -17,6 +17,7 @@ from transformers.models.phi4_multimodal.modeling_phi4_multimodal import (
Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor) Phi4MultimodalAudioRelativeAttentionBias, adaptive_enc_mask, unfold_tensor)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
...@@ -980,6 +981,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): ...@@ -980,6 +981,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -987,14 +989,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): ...@@ -987,14 +989,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"audio": "audio":
self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
num_audios=num_audios), num_audios=num_audios,
overrides=audio_overrides),
} }
return mm_data return mm_data
......
...@@ -11,6 +11,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin, ...@@ -11,6 +11,7 @@ from transformers import (BatchFeature, PretrainedConfig, ProcessorMixin,
SequenceFeatureExtractor, SiglipVisionConfig) SequenceFeatureExtractor, SiglipVisionConfig)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -749,6 +750,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): ...@@ -749,6 +750,7 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -756,14 +758,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): ...@@ -756,14 +758,19 @@ class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]):
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None
mm_data = { mm_data = {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"audio": "audio":
self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE, self._get_dummy_audios(length=_AUDIO_MAX_SOUNDFILE_SIZE,
num_audios=num_audios), num_audios=num_audios,
overrides=audio_overrides),
} }
return mm_data return mm_data
......
...@@ -24,6 +24,7 @@ from transformers.models.pixtral.modeling_pixtral import ( ...@@ -24,6 +24,7 @@ from transformers.models.pixtral.modeling_pixtral import (
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -228,28 +229,33 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): ...@@ -228,28 +229,33 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_images = dummy_mm_data.get("image", []) dummy_images = dummy_mm_data.get("image", [])
tokenization_kwargs = {"truncation": False} tokenization_kwargs = {"truncation": False}
......
...@@ -39,6 +39,7 @@ from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import ( ...@@ -39,6 +39,7 @@ from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import (
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -212,6 +213,7 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( ...@@ -212,6 +213,7 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -228,19 +230,26 @@ class Qwen2_5OmniThinkerDummyInputsBuilder( ...@@ -228,19 +230,26 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
audio_overrides = mm_options.get("audio") if mm_options else None
mm_data = { mm_data = {
"audio": "audio":
self._get_dummy_audios(length=target_audio_length, self._get_dummy_audios(length=target_audio_length,
num_audios=num_audios), num_audios=num_audios,
overrides=audio_overrides),
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"video": "video":
self._get_dummy_videos(width=target_width, self._get_dummy_videos(width=target_width,
height=target_height, height=target_height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos), num_videos=num_videos,
overrides=video_overrides),
} }
return mm_data return mm_data
......
...@@ -34,6 +34,7 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig, ...@@ -34,6 +34,7 @@ from transformers.models.qwen2_audio import (Qwen2AudioConfig,
from transformers.models.whisper import WhisperFeatureExtractor from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (AudioItem, ModalityData, from vllm.multimodal.inputs import (AudioItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig, MultiModalDataDict, MultiModalFieldConfig,
...@@ -144,6 +145,7 @@ class Qwen2AudioDummyInputsBuilder( ...@@ -144,6 +145,7 @@ class Qwen2AudioDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
...@@ -151,9 +153,13 @@ class Qwen2AudioDummyInputsBuilder( ...@@ -151,9 +153,13 @@ class Qwen2AudioDummyInputsBuilder(
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
return { return {
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
} }
......
...@@ -45,6 +45,7 @@ from vllm.attention.backends.registry import _Backend ...@@ -45,6 +45,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import (check_upstream_fa_availability, from vllm.attention.layer import (check_upstream_fa_availability,
maybe_get_vit_flash_attn_backend) maybe_get_vit_flash_attn_backend)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -1034,6 +1035,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ...@@ -1034,6 +1035,7 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -1043,17 +1045,22 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]): ...@@ -1043,17 +1045,22 @@ class Qwen2VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2VLProcessingInfo]):
target_num_frames = \ target_num_frames = \
self.info.get_num_frames_with_most_features(seq_len, mm_counts) self.info.get_num_frames_with_most_features(seq_len, mm_counts)
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_width, width=target_width,
height=target_height, height=target_height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
overrides=video_overrides,
) )
} }
......
...@@ -47,6 +47,7 @@ from vllm.attention.backends.registry import _Backend ...@@ -47,6 +47,7 @@ from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability from vllm.attention.layer import check_upstream_fa_availability
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
...@@ -736,6 +737,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): ...@@ -736,6 +737,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0) num_videos = mm_counts.get("video", 0)
...@@ -750,17 +752,23 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): ...@@ -750,17 +752,23 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
num_frames=target_num_frames, num_frames=target_num_frames,
image_processor=self.info.get_video_processor(), image_processor=self.info.get_video_processor(),
) )
image_overrides = mm_options.get("image") if mm_options else None
video_overrides = mm_options.get("video") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
"video": "video":
self._get_dummy_videos( self._get_dummy_videos(
width=target_video_size.width, width=target_video_size.width,
height=target_video_size.height, height=target_video_size.height,
num_frames=target_num_frames, num_frames=target_num_frames,
num_videos=num_videos, num_videos=num_videos,
overrides=video_overrides,
), ),
} }
......
...@@ -24,6 +24,7 @@ from transformers.image_utils import ImageInput ...@@ -24,6 +24,7 @@ from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
...@@ -567,6 +568,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): ...@@ -567,6 +568,7 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
vision_config = hf_config.visual vision_config = hf_config.visual
...@@ -574,11 +576,14 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]): ...@@ -574,11 +576,14 @@ class QwenVLDummyInputsBuilder(BaseDummyInputsBuilder[QwenVLProcessingInfo]):
target_width = target_height = vision_config["image_size"] target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -2,12 +2,14 @@ ...@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict
...@@ -38,17 +40,21 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): ...@@ -38,17 +40,21 @@ class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = ( target_width, target_height = (
self.info.get_image_size_with_most_features()) self.info.get_image_size_with_most_features())
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
} }
......
...@@ -17,6 +17,7 @@ from PIL import Image ...@@ -17,6 +17,7 @@ from PIL import Image
from transformers import BatchFeature, PretrainedConfig, TensorType from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
...@@ -505,16 +506,20 @@ class SkyworkR1VDummyInputsBuilder( ...@@ -505,16 +506,20 @@ class SkyworkR1VDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -17,6 +17,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType ...@@ -17,6 +17,7 @@ from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -496,16 +497,20 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]): ...@@ -496,16 +497,20 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
target_width, target_height = \ target_width, target_height = \
self.info.get_image_size_with_most_features() self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images) num_images=num_images,
overrides=image_overrides)
} }
......
...@@ -28,6 +28,8 @@ from terratorch.vllm import (DummyDataGenerator, InferenceRunner, ...@@ -28,6 +28,8 @@ from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.models.utils import AutoWeightsLoader
...@@ -48,6 +50,8 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings, ...@@ -48,6 +50,8 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
SupportsMultiModal) SupportsMultiModal)
from .interfaces_base import default_pooling_type from .interfaces_base import default_pooling_type
logger = init_logger(__name__)
def _terratorch_field_names(pretrained_cfg: dict): def _terratorch_field_names(pretrained_cfg: dict):
input_definition = InputDefinition(**pretrained_cfg["input"]) input_definition = InputDefinition(**pretrained_cfg["input"])
...@@ -97,9 +101,16 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]): ...@@ -97,9 +101,16 @@ class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
# Dummy data is generated based on the 'input' section # Dummy data is generated based on the 'input' section
# defined in the HF configuration file # defined in the HF configuration file
if mm_options:
logger.warning("Configurable multimodal profiling "
"options are not supported for Terratorch. "
"They are ignored for now.")
return self.dummy_data_generator.get_dummy_mm_data() return self.dummy_data_generator.get_dummy_mm_data()
......
...@@ -33,6 +33,7 @@ from vllm.attention import Attention, AttentionType ...@@ -33,6 +33,7 @@ from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig) ParallelConfig, VllmConfig)
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices from vllm.distributed.utils import get_pp_indices
...@@ -285,16 +286,20 @@ class MultiModalDummyInputsBuilder( ...@@ -285,16 +286,20 @@ class MultiModalDummyInputsBuilder(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size() target_width, target_height = self.info.get_max_image_size()
image_overrides = mm_options.get("image") if mm_options else None
return { return {
"image": "image":
self._get_dummy_images(width=target_width, self._get_dummy_images(width=target_width,
height=target_height, height=target_height,
num_images=num_images), num_images=num_images,
overrides=image_overrides),
} }
......
...@@ -14,6 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -14,6 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.model_loader import DefaultModelLoader from vllm.model_executor.model_loader import DefaultModelLoader
...@@ -114,6 +115,7 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] ...@@ -114,6 +115,7 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
...@@ -122,9 +124,13 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo] ...@@ -122,9 +124,13 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo]
_MAX_ENCODER_BATCH_SIZE) _MAX_ENCODER_BATCH_SIZE)
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
return { return {
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
} }
......
...@@ -21,6 +21,7 @@ from transformers import BatchFeature, TensorType, WhisperConfig ...@@ -21,6 +21,7 @@ from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -204,25 +205,31 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ...@@ -204,25 +205,31 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
target_length = self.info.get_max_audio_array_len() target_length = self.info.get_max_audio_array_len()
audio_overrides = mm_options.get("audio") if mm_options else None
return { return {
"audio": "audio":
self._get_dummy_audios(length=target_length, num_audios=num_audios) self._get_dummy_audios(length=target_length,
num_audios=num_audios,
overrides=audio_overrides)
} }
def get_dummy_processor_inputs( def get_dummy_processor_inputs(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
dummy_audios = dummy_mm_data.get("audio", []) dummy_audios = dummy_mm_data.get("audio", [])
audio_chunks: list[AudioChunk] = [] audio_chunks: list[AudioChunk] = []
......
...@@ -18,6 +18,7 @@ from vllm.attention.layer import MultiHeadAttention ...@@ -18,6 +18,7 @@ from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig) VllmConfig)
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -691,6 +692,7 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): ...@@ -691,6 +692,7 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
feature_extractor = self.info.get_feature_extractor() feature_extractor = self.info.get_feature_extractor()
...@@ -698,9 +700,13 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]): ...@@ -698,9 +700,13 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
audio_len = feature_extractor.chunk_length * sampling_rate audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0) num_audios = mm_counts.get("audio", 0)
audio_overrides = mm_options.get("audio") if mm_options else None
return { return {
"audio": "audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios) self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
} }
......
...@@ -10,6 +10,8 @@ import numpy.typing as npt ...@@ -10,6 +10,8 @@ import numpy.typing as npt
from PIL import Image from PIL import Image
import vllm.envs as envs import vllm.envs as envs
from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
ImageDummyOptions, VideoDummyOptions)
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
...@@ -73,10 +75,19 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -73,10 +75,19 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalDataDict: ) -> MultiModalDataDict:
""" """
Build the multimodal input which, after processing, results in Build the multimodal input which, after processing, results in
the maximum possible number of placeholder tokens. the maximum possible number of placeholder tokens.
Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional).
If None, use model defaults for backward compatibility.
If provided, models can use these to customize dummy
data generation.
""" """
raise NotImplementedError raise NotImplementedError
...@@ -84,13 +95,22 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -84,13 +95,22 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> ProcessorInputs: ) -> ProcessorInputs:
""" """
Build the input which, after processing, results in Build the input which, after processing, results in
the maximum possible number of placeholder tokens. the maximum possible number of placeholder tokens.
Args:
seq_len: Sequence length
mm_counts: Count of items per modality
mm_options: Configurable options per modality (optional)
""" """
dummy_text = self.get_dummy_text(mm_counts) dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts)
# Use the unified function for both legacy and configurable cases
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
tokenization_kwargs = {"truncation": False} tokenization_kwargs = {"truncation": False}
return ProcessorInputs(prompt=dummy_text, return ProcessorInputs(prompt=dummy_text,
...@@ -102,9 +122,17 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -102,9 +122,17 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
*, *,
length: int, length: int,
num_audios: int, num_audios: int,
overrides: Optional[AudioDummyOptions] = None,
) -> list[npt.NDArray]: ) -> list[npt.NDArray]:
if num_audios == 0: if num_audios == 0:
return [] return []
if overrides and overrides.length:
if overrides.length > length:
logger.warning(
"audio.length override (%d) exceeds model's "
"maximum length (%d), will be ignored", overrides.length,
length)
length = min(length, overrides.length)
audio = np.zeros((length, )) audio = np.zeros((length, ))
return [audio] * num_audios return [audio] * num_audios
...@@ -114,9 +142,25 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -114,9 +142,25 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
width: int, width: int,
height: int, height: int,
num_images: int, num_images: int,
overrides: Optional[ImageDummyOptions] = None,
) -> list[Image.Image]: ) -> list[Image.Image]:
if num_images == 0: if num_images == 0:
return [] return []
if overrides:
if overrides.width:
if overrides.width > width:
logger.warning(
"image.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"image.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, overrides.height)
image = Image.new("RGB", (width, height), color=255) image = Image.new("RGB", (width, height), color=255)
return [image] * num_images return [image] * num_images
...@@ -127,9 +171,32 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): ...@@ -127,9 +171,32 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height: int, height: int,
num_frames: int, num_frames: int,
num_videos: int, num_videos: int,
overrides: Optional[VideoDummyOptions] = None,
) -> list[npt.NDArray]: ) -> list[npt.NDArray]:
if num_videos == 0: if num_videos == 0:
return [] return []
if overrides:
if overrides.num_frames:
if overrides.num_frames > num_frames:
logger.warning(
"video.num_frames override (%d) exceeds model's "
"maximum number of frames (%d), will be ignored",
overrides.num_frames, num_frames)
num_frames = min(num_frames, overrides.num_frames)
if overrides.width:
if overrides.width > width:
logger.warning(
"video.width override (%d) exceeds model's "
"maximum width (%d), will be ignored", overrides.width,
width)
width = min(width, overrides.width)
if overrides.height:
if overrides.height > height:
logger.warning(
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored",
overrides.height, height)
height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255) video = np.full((num_frames, width, height, 3), 255)
return [video] * num_videos return [video] * num_videos
...@@ -162,13 +229,14 @@ class MultiModalProfiler(Generic[_I]): ...@@ -162,13 +229,14 @@ class MultiModalProfiler(Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None, mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> MultiModalInputs: ) -> MultiModalInputs:
if mm_counts is None: if mm_counts is None:
mm_counts = self.get_mm_limits() mm_counts = self.get_mm_limits()
factory = self.dummy_inputs factory = self.dummy_inputs
processor_inputs = factory.get_dummy_processor_inputs( processor_inputs = factory.get_dummy_processor_inputs(
seq_len, mm_counts) seq_len, mm_counts, mm_options)
return self.processor.apply( return self.processor.apply(
prompt=processor_inputs.prompt, prompt=processor_inputs.prompt,
...@@ -195,8 +263,9 @@ class MultiModalProfiler(Generic[_I]): ...@@ -195,8 +263,9 @@ class MultiModalProfiler(Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None, mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> DummyEncoderData: ) -> DummyEncoderData:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
mm_inputs = cast(MultiModalEncDecInputs, mm_inputs) mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
# For encoder-decoder models, use encoder prompt token ids instead of # For encoder-decoder models, use encoder prompt token ids instead of
...@@ -228,8 +297,9 @@ class MultiModalProfiler(Generic[_I]): ...@@ -228,8 +297,9 @@ class MultiModalProfiler(Generic[_I]):
self, self,
seq_len: int, seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None, mm_counts: Optional[Mapping[str, int]] = None,
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
) -> DummyDecoderData: ) -> DummyDecoderData:
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
prompt_token_ids = mm_inputs["prompt_token_ids"] prompt_token_ids = mm_inputs["prompt_token_ids"]
total_len = len(prompt_token_ids) total_len = len(prompt_token_ids)
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn import torch.nn as nn
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config) cached_tokenizer_from_config)
...@@ -52,7 +53,7 @@ class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc] ...@@ -52,7 +53,7 @@ class DummyInputsBuilderFactory(Protocol[_I]): # type: ignore[misc]
... ...
class MultiModalProcessorFactory(Protocol[_I]): class MultiModalProcessorFactory(Protocol[_I]): # type: ignore[misc]
""" """
Constructs a Constructs a
[`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor] [`BaseMultiModalProcessor`][vllm.multimodal.processing.BaseMultiModalProcessor]
...@@ -95,6 +96,28 @@ class MultiModalRegistry: ...@@ -95,6 +96,28 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module, self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]() _ProcessorFactories]()
def _extract_mm_options(
    self,
    model_config: "ModelConfig",
) -> Optional[Mapping[str, BaseDummyOptions]]:
    """
    Collect the per-modality dummy options declared in the model config.

    Each modality listed in ``limit_per_prompt`` is looked up through
    ``get_dummy_options``; modalities for which that lookup yields ``None``
    are left out of the result.

    Returns None when the model has no multimodal config or when no
    modality provides options, otherwise a mapping of modality names to
    their dummy options.
    """
    mm_config = model_config.multimodal_config
    if not mm_config:
        return None

    collected = {}
    for modality in mm_config.limit_per_prompt:
        options = mm_config.get_dummy_options(modality)
        if options is not None:
            collected[modality] = options

    # An empty mapping is reported as None so callers see "no overrides".
    if len(collected) == 0:
        return None
    return collected
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
""" """
Checks if the model supports multimodal inputs. Checks if the model supports multimodal inputs.
...@@ -135,7 +158,7 @@ class MultiModalRegistry: ...@@ -135,7 +158,7 @@ class MultiModalRegistry:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
...@@ -189,7 +212,7 @@ class MultiModalRegistry: ...@@ -189,7 +212,7 @@ class MultiModalRegistry:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
def register_processor( def register_processor(
...@@ -285,8 +308,15 @@ class MultiModalRegistry: ...@@ -285,8 +308,15 @@ class MultiModalRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts,
mm_options)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids token_ids = dummy_data.prompt_token_ids
...@@ -311,8 +341,15 @@ class MultiModalRegistry: ...@@ -311,8 +341,15 @@ class MultiModalRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
# Extract configurable options from multimodal config.
# Only include modalities that use advanced option types so legacy
# count-only behavior remains unchanged.
mm_options = self._extract_mm_options(model_config)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts,
mm_options)
# Having more tokens is over-conservative but otherwise fine # Having more tokens is over-conservative but otherwise fine
token_ids = dummy_data.prompt_token_ids token_ids = dummy_data.prompt_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment