Unverified Commit 06da44f0 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate LlavaImageInputs to TensorSchema (#21770)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent a5549917
......@@ -3,7 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast)
import torch
......@@ -33,6 +33,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......@@ -44,35 +45,46 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .vision import get_vision_encoder_info
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
class LlavaImagePixelInputs(TensorSchema):
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class PixtralHFImagePixelInputs(TypedDict):
type: Literal["pixel_values_pixtral"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class PixtralHFImagePixelInputs(TensorSchema):
"""
Shape: `(batch_size * num_images, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of images
- c: Number of channels
- h: Height
- w: Width
Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values_pixtral"] = "pixel_values_pixtral"
pixel_values: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "c", "h", "w")]
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class LlavaImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
LlavaImageInputs = Union[LlavaImagePixelInputs, PixtralHFImagePixelInputs,
......@@ -547,19 +559,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
......@@ -579,10 +578,14 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
pixel_values=flatten_bn(pixel_values),
)
expected_h = expected_w = self.config.vision_config.image_size
return LlavaImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
pixel_values=flatten_bn(pixel_values, concat=True),
resolve_bindings={
"h": expected_h,
"w": expected_w
},
)
if image_embeds is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment