Commit a70d0bd0 (unverified), authored by Benji Beck, committed by GitHub
Browse files

Migrate LlavaOnevisionMultiInputs to TensorSchema (#21844)


Signed-off-by: Benji Beck <benjibeck@meta.com>
parent 24f4d1a2
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Final, Literal, Optional, Protocol, TypedDict, Union from typing import Annotated, Final, Literal, Optional, Protocol, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig, ...@@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig,
LlavaOnevisionProcessor) LlavaOnevisionProcessor)
from transformers.models.llava_onevision.modeling_llava_onevision import ( from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, ...@@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, ...@@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
_MAX_FRAMES_PER_VIDEO = 16 _MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict): class LlavaOnevisionVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"]
pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)` Dimensions:
- bn: Batch size * number of videos
- f: Number of frames
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_videos` may be different for each batch, and 'num_frames' Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a may be different for each video, in which case the data is passed as a
list instead of a batched tensor. list instead of a batched tensor.
""" """
type: Literal["pixel_values_videos"] = "pixel_values_videos"
pixel_values_videos: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
]
class LlavaOnevisionImagePixelInputs(TypedDict): class LlavaOnevisionImagePixelInputs(TensorSchema):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: Dimensions:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - bn: Batch size * number of images
- np: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image, Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
type: Literal["pixel_values"] = "pixel_values"
image_sizes: NotRequired[torch.Tensor] pixel_values: Annotated[
""" Union[torch.Tensor, list[torch.Tensor]],
Shape: `(batch_size * num_images, 2)` TensorShape("bn", "np", 3, "h", "w"),
]
This should be in `(height, width)` format.
"""
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
class LlavaOnevisionImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone. class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
""" """
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
torch.Tensor,
TensorShape("bn", "ifs", "hs"),
]
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
...@@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
...@@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaOnevisionImagePixelInputs( return LlavaOnevisionImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_image_pixel_values( pixel_values=flatten_bn(pixel_values),
flatten_bn(pixel_values)), image_sizes=flatten_bn(image_sizes, concat=True),
image_sizes=self._validate_image_sizes( resolve_bindings={
flatten_bn(image_sizes, concat=True)), "h": self.config.vision_config.image_size,
) "w": self.config.vision_config.image_size
})
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor): if not isinstance(image_embeds, torch.Tensor):
...@@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input( def _parse_and_validate_video_input(
self, self,
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
...@@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaOnevisionVideoPixelInputs( return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
pixel_values_videos=flatten_bn(pixel_values_videos), pixel_values_videos=flatten_bn(pixel_values_videos),
) resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size
})
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {} mm_input_by_modality = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment