Unverified Commit 365801fe authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Add max-count checking in data parser for single image models (#11661)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarRoger Wang <ywang@roblox.com>
Co-authored-by: default avatarRoger Wang <ywang@roblox.com>
parent 4db72e57
...@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ
- [V1](gh-issue:8779) - [V1](gh-issue:8779)
* - `AriaForConditionalGeneration` * - `AriaForConditionalGeneration`
- Aria - Aria
- T + I - T + I<sup>+</sup>
- `rhymes-ai/Aria` - `rhymes-ai/Aria`
- -
- ✅︎ - ✅︎
......
...@@ -622,10 +622,11 @@ def _test_processing_cache_correctness( ...@@ -622,10 +622,11 @@ def _test_processing_cache_correctness(
# yapf: disable # yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize(("model_id", "modalities"), [ @pytest.mark.parametrize(("model_id", "modalities"), [
("rhymes-ai/Aria", {"image": True}), ("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}), ("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": True}), ("facebook/chameleon-7b", {"image": False}),
("adept/fuyu-8b", {"image": False}), ("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-1.5-7b-hf", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
......
...@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -18,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
...@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext): ...@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):
class Blip2MultiModalProcessor(BaseMultiModalProcessor): class Blip2MultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> Blip2Processor: def _get_hf_processor(self) -> Blip2Processor:
return self.ctx.get_hf_processor(Blip2Processor) return self.ctx.get_hf_processor(Blip2Processor)
......
...@@ -31,6 +31,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -31,6 +31,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
...@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext): ...@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
class ChameleonMultiModalProcessor(BaseMultiModalProcessor): class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> ChameleonProcessor: def _get_hf_processor(self) -> ChameleonProcessor:
return self.ctx.get_hf_processor(ChameleonProcessor) return self.ctx.get_hf_processor(ChameleonProcessor)
......
...@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -34,7 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange) NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs, MultiModalDataItems, ProcessorInputs,
PromptReplacement) PromptReplacement)
...@@ -54,7 +54,7 @@ MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 ...@@ -54,7 +54,7 @@ MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePatchInputs(TypedDict): class FuyuImagePatchInputs(TypedDict):
type: Literal["image_patches"] type: Literal["image_patches"]
data: torch.Tensor flat_data: torch.Tensor
""" """
Shape: Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)` `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
...@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict): ...@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
patches_per_image: List[int] patches_per_image: List[int]
""" """
List of number of total patches for each image in the batch. List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `data`. This is used to restore the first two dimensions of `flat_data`.
""" """
...@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext): ...@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
class FuyuMultiModalProcessor(BaseMultiModalProcessor): class FuyuMultiModalProcessor(BaseMultiModalProcessor):
def _get_data_parser(self) -> MultiModalDataParser:
return MultiModalDataParser(max_mm_counts={"image": 1})
def _get_hf_processor(self) -> FuyuProcessor: def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor) return self.ctx.get_hf_processor(FuyuProcessor)
...@@ -304,7 +307,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -304,7 +307,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
type="image_patches", type="image_patches",
data=self._validate_pixel_values( flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)), flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
) )
...@@ -313,12 +316,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -313,12 +316,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input( def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> NestedTensors: self, image_input: FuyuImagePatchInputs) -> NestedTensors:
image_patches = image_input["data"] image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"] patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_patches) vision_embeddings_flat, _ = self.vision_embed_tokens(
return vision_embeddings.split(patches_per_image, dim=0) image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
...@@ -220,11 +220,24 @@ ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]], ...@@ -220,11 +220,24 @@ ModalityDataParser: TypeAlias = Callable[[ModalityData[Any]],
class MultiModalDataParser: class MultiModalDataParser:
""" """
Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`. Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
Args:
max_mm_counts (Mapping[str, int]): The maximum allowed number of items
belonging to each modality. This effectively sets a hard limit over
`--limit-mm-per-prompt`.
target_sr (float, optional): Enables automatic resampling of audio
items to the model's expected sampling rate.
""" """
def __init__(self, *, target_sr: Optional[float] = None) -> None: def __init__(
self,
*,
max_mm_counts: Mapping[str, int] = {},
target_sr: Optional[float] = None,
) -> None:
super().__init__() super().__init__()
self.max_mm_counts = max_mm_counts
self.target_sr = target_sr self.target_sr = target_sr
def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]: def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:
...@@ -332,6 +345,7 @@ class MultiModalDataParser: ...@@ -332,6 +345,7 @@ class MultiModalDataParser:
def parse_mm_data(self, def parse_mm_data(self,
mm_data: MultiModalDataDict) -> MultiModalDataItems: mm_data: MultiModalDataDict) -> MultiModalDataItems:
max_mm_counts = self.max_mm_counts
subparsers = self._get_subparsers() subparsers = self._get_subparsers()
mm_items = MultiModalDataItems() mm_items = MultiModalDataItems()
...@@ -339,6 +353,16 @@ class MultiModalDataParser: ...@@ -339,6 +353,16 @@ class MultiModalDataParser:
if k not in subparsers: if k not in subparsers:
raise ValueError(f"Unsupported modality: {k}") raise ValueError(f"Unsupported modality: {k}")
mm_items[k] = subparsers[k](v) modality_items = subparsers[k](v)
if k in max_mm_counts:
max_count = max_mm_counts[k]
if len(modality_items) > max_count:
raise ValueError(
f"This model supports at most {max_count} {k} items "
f"per prompt, but {len(modality_items)} {k} items "
"were given or set as its limit_mm_per_prompt.")
mm_items[k] = modality_items
return mm_items return mm_items
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment