Unverified Commit 05fb6718 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Clean up multi-modal processors (#14417)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 12c29a88
...@@ -2405,6 +2405,15 @@ class MultiModalConfig: ...@@ -2405,6 +2405,15 @@ class MultiModalConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest() hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str return hash_str
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
If not set by the user, this defaults to `1`.
"""
return self.limit_per_prompt.get(modality, 1)
# TODO: Add configs to init vision tower or not. # TODO: Add configs to init vision tower or not.
......
...@@ -14,7 +14,6 @@ from einops import rearrange, repeat ...@@ -14,7 +14,6 @@ from einops import rearrange, repeat
from transformers import BatchFeature from transformers import BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...@@ -25,8 +24,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -25,8 +24,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, PromptReplacement,
PromptReplacement, PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
...@@ -42,8 +41,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -42,8 +41,6 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__)
# The image token id may be various # The image token id may be various
_IMAGE_TOKEN = "<image>" _IMAGE_TOKEN = "<image>"
...@@ -216,30 +213,6 @@ class DeepseekVL2DummyInputsBuilder( ...@@ -216,30 +213,6 @@ class DeepseekVL2DummyInputsBuilder(
class DeepseekVL2MultiModalProcessor( class DeepseekVL2MultiModalProcessor(
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]): BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):
def __init__(
self,
info: DeepseekVL2ProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] > 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"image limit larger than 2.")
def _call_hf_processor( def _call_hf_processor(
self, self,
prompt: str, prompt: str,
...@@ -316,6 +289,31 @@ class DeepseekVL2MultiModalProcessor( ...@@ -316,6 +289,31 @@ class DeepseekVL2MultiModalProcessor(
) )
] ]
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
DeepseekVL2MultiModalProcessor, DeepseekVL2MultiModalProcessor,
......
...@@ -8,21 +8,19 @@ ...@@ -8,21 +8,19 @@
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Optional from typing import Optional, Union
import torch import torch
from PIL import Image from PIL import Image
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdate, PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
...@@ -32,8 +30,6 @@ from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, ...@@ -32,8 +30,6 @@ from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
InternVLMultiModalProcessor, build_transform, InternVLMultiModalProcessor, build_transform,
find_closest_aspect_ratio, get_internvl_target_ratios) find_closest_aspect_ratio, get_internvl_target_ratios)
logger = init_logger(__name__)
def resolve_h2ovl_min_max_num( def resolve_h2ovl_min_max_num(
*, *,
...@@ -465,29 +461,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -465,29 +461,6 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
): ):
def __init__(self,
info: H2OVLProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] >= 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")
def _get_prompt_updates( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -543,6 +516,31 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ...@@ -543,6 +516,31 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
) )
] ]
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
H2OVLMultiModalProcessor, H2OVLMultiModalProcessor,
......
...@@ -133,7 +133,7 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo): ...@@ -133,7 +133,7 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.get_limit_per_prompt("video")
max_total_frames = self._get_max_video_frames(seq_len) max_total_frames = self._get_max_video_frames(seq_len)
......
...@@ -206,8 +206,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo): ...@@ -206,8 +206,8 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.get_limit_per_prompt("video")
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
......
...@@ -201,9 +201,9 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -201,9 +201,9 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.get_limit_per_prompt("video")
max_audios = mm_config.limit_per_prompt.get("audio", 1) max_audios = mm_config.get_limit_per_prompt("audio")
# count <image_idx></image_idx> tokens # count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens # which are not in get_max_image_tokens
......
...@@ -446,8 +446,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo): ...@@ -446,8 +446,8 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.get_limit_per_prompt("video")
# count <image_idx></image_idx> tokens # count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens # which are not in get_max_image_tokens
......
...@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, ...@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_token_id = mm_encoder.special_ids.img image_token_id = mm_encoder.special_ids.img
mm_config = ctx.get_mm_config() mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1) num_images = mm_config.get_limit_per_prompt("image")
# dummy size # dummy size
size = 256 size = 256
......
...@@ -911,8 +911,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -911,8 +911,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_num_frames_with_most_features(self, seq_len: int) -> int: def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config() mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1) max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.limit_per_prompt.get("video", 1) max_videos = mm_config.get_limit_per_prompt("video")
max_image_tokens = self.get_max_image_tokens() * max_images max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len - max_total_frames = self._get_max_video_frames(seq_len -
......
...@@ -984,10 +984,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -984,10 +984,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
before passing them to :meth:`_get_hf_mm_data`. before passing them to :meth:`_get_hf_mm_data`.
""" """
mm_items = self.data_parser.parse_mm_data(mm_data) mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.get_mm_config()
mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items(): for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1) limit = mm_config.get_limit_per_prompt(modality)
if len(items) > limit: if len(items) > limit:
raise ValueError( raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in " f"You set {modality}={limit} (or defaulted to 1) in "
......
...@@ -110,12 +110,10 @@ class MultiModalProfiler(Generic[_I]): ...@@ -110,12 +110,10 @@ class MultiModalProfiler(Generic[_I]):
def get_mm_limits(self) -> Mapping[str, int]: def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config() mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt
supported_mm_limits = self.processing_info.get_supported_mm_limits() supported_mm_limits = self.processing_info.get_supported_mm_limits()
mm_limits = { mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1) modality: mm_config.get_limit_per_prompt(modality)
for modality in supported_mm_limits for modality in supported_mm_limits
} }
......
...@@ -355,7 +355,7 @@ class MultiModalRegistry: ...@@ -355,7 +355,7 @@ class MultiModalRegistry:
# TODO: Automatically determine the limits based on budget # TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs # once more models support multi-image inputs
limits_per_plugin = { limits_per_plugin = {
key: config_limits_per_plugin.get(key, 1) key: multimodal_config.get_limit_per_prompt(key)
for key in self._plugins for key in self._plugins
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment