Unverified Commit 9d197280 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate AriaImagePixelInputs to TensorSchema for shape validation (#21620)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent e98def43
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, TypedDict, Union from typing import Annotated, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
# yapf: disable # yapf: disable
from .idefics2_vision_model import Idefics2VisionConfig from .idefics2_vision_model import Idefics2VisionConfig
...@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -42,15 +43,26 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings) merge_multimodal_embeddings)
class AriaImagePixelInputs(TypedDict): class AriaImagePixelInputs(TensorSchema):
pixel_values: torch.Tensor
pixel_mask: Optional[torch.Tensor]
""" """
Shape: Dimensions:
pixel_values: `(batch_size * num_images, num_channels, height, width)` - b: Batch size
pixel_mask: `(batch_size * num_images, height, width)` - n: Number of images
- c: Number of channels
- h: Height of each image
- w: Width of each image
""" """
pixel_values: Annotated[
torch.Tensor,
TensorShape("bn", 3, "h", "w"),
]
pixel_mask: Annotated[
Optional[torch.Tensor],
TensorShape("bn", "h", "w"),
]
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant): class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
...@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -540,12 +552,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale) self.vocab_size, logit_scale)
def _validate_image_sizes(
self, images: list[torch.Tensor]) -> list[torch.Tensor]:
if not all(img.shape == images[0].shape for img in images):
raise ValueError("All images must be the same size")
return images
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[AriaImagePixelInputs]: self, **kwargs: object) -> Optional[AriaImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -554,23 +560,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
if pixel_values is None: if pixel_values is None:
return None return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = self._validate_image_sizes(pixel_values)
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_mask is not None:
if not isinstance(pixel_mask, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel mask. "
f"Got type: {type(pixel_mask)}")
pixel_mask = flatten_bn(pixel_mask, concat=True)
return AriaImagePixelInputs( return AriaImagePixelInputs(
pixel_values=pixel_values, pixel_values=flatten_bn(pixel_values, concat=True),
pixel_mask=pixel_mask, pixel_mask=flatten_bn(pixel_mask, concat=True),
) )
def _create_patch_attention_mask( def _create_patch_attention_mask(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment