Unverified commit 27261e40, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Multi-video inference on LLaVA-Onevision (#15082)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
parent e3f813c3
...@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, ...@@ -25,7 +25,6 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict): ...@@ -44,7 +43,7 @@ class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]] pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
""" """
Shape: `(batch_size, num_videos, num_frames, num_channels, height, width)` Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
Note that `num_videos` may be different for each batch, and 'num_frames' Note that `num_videos` may be different for each batch, and 'num_frames'
may be different for each video, in which case the data is passed as a may be different for each video, in which case the data is passed as a
...@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -580,7 +579,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaOnevisionVideoPixelInputs( return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos", type="pixel_values_videos",
pixel_values_videos=pixel_values_videos, pixel_values_videos=flatten_bn(pixel_values_videos),
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
...@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -768,22 +767,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
for i, patch_features_batch in enumerate(patch_embeddings) for i, patch_features_batch in enumerate(patch_embeddings)
] ]
def _add_image_newline(
    self,
    video_features: torch.Tensor,
    videos: int = 1,
    frames: int = 1,
    strategy: str = "one_token",
) -> torch.Tensor:
    """Append a single ``image_newline`` token after each video's features.

    Args:
        video_features: Per-frame features. ``shape[1]`` is read as the
            number of tokens per frame, and the tensor is reshaped to
            ``(videos, frames * shape[1], -1)``, so its leading dimension
            must hold ``videos * frames`` entries.
        videos: Number of videos packed into ``video_features``.
        frames: Number of frames per video.
        strategy: Newline placement strategy; only ``"one_token"`` is
            implemented.

    Returns:
        Tensor of shape ``(videos, frames * tokens_per_frame + 1, dim)``
        whose last position along dim 1 is ``self.image_newline``.

    Raises:
        ValueError: If ``strategy`` is not ``"one_token"``.
    """
    # Reject unsupported strategies up front so the main path stays flat.
    if strategy != "one_token":
        raise ValueError(f"Unexpected video newline strategy: {strategy}")

    tokens_per_frame = video_features.shape[1]
    # Merge the frame axis into the token axis: one row of tokens per video.
    video_features = video_features.reshape(
        videos, frames * tokens_per_frame, -1)

    # Move the single newline vector to the target device first, then
    # broadcast it per video with ``expand`` (a view). ``torch.cat`` copies
    # its inputs into the result anyway, so ``repeat`` would only add an
    # extra intermediate allocation.
    image_newline = self.image_newline.to(video_features.device)
    image_newline = image_newline[None, None, :].expand(videos, -1, -1)

    return torch.cat((video_features, image_newline), dim=1)
def _video_pixels_to_features( def _video_pixels_to_features(
self, self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel], vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
...@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -807,33 +790,43 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
video_pixels = inputs["pixel_values_videos"] video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor): if isinstance(video_pixels, torch.Tensor):
b, num_videos, frames, c, h, w = video_pixels.shape total_videos, frames, c, h, w = video_pixels.shape
pixel_values = video_pixels.view(b * num_videos * frames, c, h, w) video_pixels_flat = video_pixels.view(total_videos * frames, c, h,
stacked_embeddings = self._video_pixels_to_features( w)
self.vision_tower, pixel_values)
stacked_embeddings = self._add_image_newline(stacked_embeddings, embeddings_flat = self._video_pixels_to_features(
videos=b * num_videos, self.vision_tower, video_pixels_flat)
frames=frames,
strategy="one_token") embeddings_flat = embeddings_flat.reshape(
return stacked_embeddings total_videos, frames * embeddings_flat.shape[1], -1)
elif is_list_of(video_pixels, torch.Tensor):
stacked_embeddings = [] image_newline = self.image_newline[None, None, :].expand(
for video_pixel in video_pixels: total_videos, -1, -1)
num_videos, frames, c, h, w = video_pixel.shape return torch.cat((embeddings_flat, image_newline), dim=1)
pixel_values = video_pixel.view(num_videos * frames, c, h, w)
embeddings = self._video_pixels_to_features( frames_per_video = [len(video) for video in video_pixels]
self.vision_tower, pixel_values) video_pixels_flat = torch.cat(video_pixels)
embeddings = self._add_image_newline(embeddings,
videos=num_videos, embeddings_flat = self._video_pixels_to_features(
frames=frames, self.vision_tower, video_pixels_flat)
strategy="one_token")
stacked_embeddings.append(embeddings) image_newline = self.image_newline[None, None, :]
return stacked_embeddings
else: return [
raise ValueError( torch.cat(
f"Unsupported type of video input {type(video_pixels)}") (
embeds.reshape(1, num_frame * embeddings_flat.shape[1],
-1),
image_newline,
),
dim=1,
) for num_frame, embeds in zip(
frames_per_video,
torch.split(embeddings_flat, frames_per_video),
)
]
def apply_pooling(self, image_features, stride=2): def apply_pooling(self, image_features: torch.Tensor, stride: int = 2):
vision_config = self.config.vision_config vision_config = self.config.vision_config
height = width = vision_config.image_size // vision_config.patch_size height = width = vision_config.image_size // vision_config.patch_size
batch_frames, _, dim = image_features.shape batch_frames, _, dim = image_features.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment