Unverified Commit b4e29167 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate LlavaNextImageInputs to TensorSchema (#21774)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 65a7917b
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union) Union)
import torch import torch
...@@ -11,7 +11,6 @@ import torch.nn as nn ...@@ -11,7 +11,6 @@ import torch.nn as nn
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
from transformers.models.llava_next.modeling_llava_next import ( from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -19,6 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import ImageSize from vllm.multimodal.parse import ImageSize
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal, ...@@ -30,32 +30,36 @@ from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
flatten_bn, init_vllm_registered_model, maybe_prefix) flatten_bn, init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: Dimensions:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - bn: Batch size * number of images
- np: Number of patches + 1
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image, Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"})]
image_sizes: NotRequired[torch.Tensor] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
""" # This should be in `(height, width)` format.
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
class LlavaNextImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. class LlavaNextImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
""" """
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
...@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -269,44 +273,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]: self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -325,13 +291,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of image sizes. " raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}") f"Got type: {type(image_sizes)}")
expected_h = expected_w = self.config.vision_config.image_size
return LlavaNextImagePixelInputs( return LlavaNextImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values( pixel_values=flatten_bn(pixel_values),
flatten_bn(pixel_values)), image_sizes=flatten_bn(image_sizes, concat=True),
image_sizes=self._validate_image_sizes( resolve_bindings={
flatten_bn(image_sizes, concat=True)), "h": expected_h,
) "w": expected_w,
})
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, torch.Tensor):
......
...@@ -60,6 +60,9 @@ class TensorSchema: ...@@ -60,6 +60,9 @@ class TensorSchema:
def __getitem__(self, item) -> Any: def __getitem__(self, item) -> Any:
return getattr(self, item) return getattr(self, item)
def get(self, item, default=None) -> Any:
return getattr(self, item, default)
def _match_shape_with_dynamic(self, actual: tuple[int, ...], def _match_shape_with_dynamic(self, actual: tuple[int, ...],
reference: tuple[int, ...], reference: tuple[int, ...],
expected_shape: tuple[Union[int, str], ...], expected_shape: tuple[Union[int, str], ...],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment