Unverified Commit b4e29167 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate LlavaNextImageInputs to TensorSchema (#21774)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 65a7917b
......@@ -3,7 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union)
import torch
......@@ -11,7 +11,6 @@ import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class LlavaNextImagePixelInputs(TensorSchema):
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of images
- np: Number of patches + 1
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
image_sizes: NotRequired[torch.Tensor]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
# This should be in `(height, width)` format.
class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class LlavaNextImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
......@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
......@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
......
......@@ -60,6 +60,9 @@ class TensorSchema:
def __getitem__(self, item) -> Any:
return getattr(self, item)
def get(self, item, default=None) -> Any:
return getattr(self, item, default)
def _match_shape_with_dynamic(self, actual: tuple[int, ...],
reference: tuple[int, ...],
expected_shape: tuple[Union[int, str], ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment