Unverified Commit a5549917 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Migrate LlavaNextVideoPixelInputs to TensorSchema (#21843)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent d1af8b7b
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union from typing import Annotated, Literal, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -25,6 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
...@@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper, ...@@ -35,17 +36,25 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
from .vision import get_vision_encoder_info from .vision import get_vision_encoder_info
class LlavaNextVideoPixelInputs(TypedDict): class LlavaNextVideoPixelInputs(TensorSchema):
type: Literal["pixel_values_videos"] """
data: Union[torch.Tensor, list[torch.Tensor]] Dimensions:
""" - bs: Batch size
Shape: `(batch_size, num_frames, num_channels, height, width)` - nv: Number of videos
- nf: Number of frames
- nc: Number of channels (3)
- h: Height of each frame
- w: Width of each frame
Note that `num_frames` may be different for each batch, in which case Note that `num_frames` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor. the data is passed as a list instead of a batched tensor.
Note that it only supports one video input for one batch. Note that it only supports one video input for one batch.
""" """
type: Literal["pixel_values_videos"] = "pixel_values_videos"
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bs", "nv", "nf", 3, "h", "w")]
class LlavaNextVideoProcessingInfo(BaseProcessingInfo): class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
...@@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -320,27 +329,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input( def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
""" """
...@@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -355,14 +343,13 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values_videos is None: if pixel_values_videos is None:
return None return None
if not isinstance(pixel_values_videos, (torch.Tensor, list)): expected_h = expected_w = self.config.vision_config.image_size
raise ValueError("Incorrect type of pixel_values_videos. " return LlavaNextVideoPixelInputs(type="pixel_values_videos",
f"Got type: {type(pixel_values_videos)}") data=pixel_values_videos,
resolve_bindings={
return LlavaNextVideoPixelInputs( "h": expected_h,
type="pixel_values_videos", "w": expected_w,
data=pixel_values_videos, })
)
def _select_image_features(self, image_features: torch.Tensor, *, def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor: strategy: str) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment