Unverified Commit de10ff0b authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate AyaVisionImagePixelInputs to TensorSchema for shape validation (#21622)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 9d197280
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union, cast from typing import Annotated, Literal, Optional, Union, cast
import torch import torch
from torch import nn from torch import nn
...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
...@@ -37,18 +38,28 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -37,18 +38,28 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings) merge_multimodal_embeddings)
class AyaVisionImagePixelInputs(TypedDict): class AyaVisionImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
""" """
Shape: `(num_patches_total, num_channels, height, width)` Dimensions:
- np: The total number of patches over each image over each prompt in
`num_patches_total` is the total number of patches over each image over each the batch
prompt in the batch. - c: Number of channels
- h: Height of each image patch
- w: Width of each image patch
- bn: Batch size * number of images
""" """
num_patches: torch.Tensor type: Literal["pixel_values"]
"""Shape: `(batch_size * num_images)`"""
pixel_values: Annotated[
torch.Tensor,
TensorShape("np", 3, "h", "w"),
]
num_patches: Annotated[
torch.Tensor,
TensorShape("bn"),
]
class AyaVisionMultiModalProjector(nn.Module): class AyaVisionMultiModalProjector(nn.Module):
...@@ -383,21 +394,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -383,21 +394,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist()) e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
] ]
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
if d.shape != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_dims}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -405,22 +401,17 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -405,22 +401,17 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Aya Vision does not support image_embeds." assert image_embeds is None, "Aya Vision does not support image_embeds."
if not isinstance(pixel_values, (torch.Tensor, list)): if pixel_values is None:
raise ValueError("Incorrect type of pixel values. " return None
f"Got type: {type(pixel_values)}")
if num_patches is not None and not isinstance(num_patches,
(torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_patches = flatten_bn(num_patches, concat=True)
return AyaVisionImagePixelInputs( return AyaVisionImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=flatten_bn(pixel_values, concat=True),
num_patches=num_patches, num_patches=flatten_bn(num_patches, concat=True),
) resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,
})
def get_language_model(self) -> torch.nn.Module: def get_language_model(self) -> torch.nn.Module:
return self.language_model return self.language_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment