Unverified Commit f1e2c095 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate InternVLImageInputs and InternVLVideoInputs to TensorSchema (#21684)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 12a223ef
......@@ -9,7 +9,7 @@
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, TypeVar, Union
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
import numpy.typing as npt
import torch
......@@ -37,6 +37,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
......@@ -51,54 +52,60 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor
class InternVLImagePixelInputs(TensorSchema):
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
"""
type: Literal["pixel_values"]
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class InternVLImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- n: Number of images
- f: Total image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["image_embeds"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("n", "f", "h")]
InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
class InternVLVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_flat: torch.Tensor
class InternVLVideoPixelInputs(TensorSchema):
"""
Shape:
`(batch_size * num_video * num_frames, num_channels, height, width)`
Dimensions:
- bvf: Batch size * number of videos * num_frames
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each video frame
- w: Width of each video frame
"""
type: Literal["pixel_values_videos"]
pixel_values_flat: Annotated[torch.Tensor, TensorShape("bvf", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class InternVLVideoEmbeddingInputs(TypedDict):
type: Literal["video_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_videos, total_video_feature_size, hidden_size)`
or a list of tensors of shape `(total_video_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class InternVLVideoEmbeddingInputs(TensorSchema):
"""
Dimensions:
- n: Number of videos
- f: Total video feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""
type: Literal["video_embeds"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("n", "f", "h")]
InternVLVideoInputs = Union[InternVLVideoPixelInputs,
......@@ -1151,26 +1158,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
......@@ -1205,12 +1192,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
expected_h = expected_w = self.config.vision_config.image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
return InternVLImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
pixel_values_flat=pixel_values_flat,
num_patches=image_num_patches,
resolve_bindings=resolve_bindings,
)
raise AssertionError("This line should be unreachable.")
......@@ -1225,11 +1214,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
return None
if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return InternVLImageEmbeddingInputs(
return InternVLVideoEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
)
......@@ -1250,12 +1235,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
expected_h = expected_w = self.config.vision_config.image_size
resolve_bindings = {"h": expected_h, "w": expected_w}
return InternVLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat_video),
pixel_values_flat=pixel_values_flat_video,
num_patches=video_num_patches,
resolve_bindings=resolve_bindings,
)
raise AssertionError("This line should be unreachable.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment