"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "72c6e8b8bf89c97f9ad1c02cb5d8161d5ef9caad"
Unverified Commit d64e4da7 authored by Raushan Turganbay, committed by GitHub

Video-LLaVa: handle any number of frames (#31221)

Video-LLaVa can now handle any number of video frames instead of requiring exactly 8.
parent 36ade4a3
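
As a quick illustration of what this change enables (a sketch, not part of the commit; the checkpoint id "LanguageBind/Video-LLaVA-7B-hf" and the prompt format are assumptions based on the public Video-LLaVa release), a 16-frame clip can now be passed directly:

    import numpy as np
    import torch
    from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor

    model_id = "LanguageBind/Video-LLaVA-7B-hf"  # assumed checkpoint id
    processor = VideoLlavaProcessor.from_pretrained(model_id)
    model = VideoLlavaForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )

    # 16 frames instead of the previously hard-coded 8
    video = np.random.randint(0, 255, size=(16, 224, 224, 3), dtype=np.uint8)
    prompt = "USER: <video>Describe the video. ASSISTANT:"

    inputs = processor(text=prompt, videos=video, return_tensors="pt").to(model.device, torch.float16)
    output_ids = model.generate(**inputs, max_new_tokens=40)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])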
@@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         num_images, num_image_patches, embed_dim = visual_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index
+        special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
         # 1. Create a mask to know where special image tokens are
         special_image_token_mask = input_ids == special_vision_token
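
The switch from `num_frames == 8` to `num_frames > 1` works because `_get_vision_features` now reports the actual frame count and falls back to 0 when no video is present (see the next hunk), so the predicate cleanly separates the image path from the video path. A stand-alone sketch of the condition, with a stub config and hypothetical token ids:

    from types import SimpleNamespace

    # Hypothetical token ids; the real values come from the model config.
    config = SimpleNamespace(video_token_index=32001, image_token_index=32000)

    def special_vision_token(num_frames: int) -> int:
        # Mirrors the new condition: any multi-frame input is treated as video.
        return config.video_token_index if num_frames > 1 else config.image_token_index

    assert special_vision_token(0) == config.image_token_index   # no video passed
    assert special_vision_token(16) == config.video_token_index  # arbitrary frame count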
@@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         # videos do not need to select features and it's always "full" (as it is done in the orig implementation)
         if pixel_values_videos is not None:
             batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
-            if num_frames != 8:
-                raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`")
             pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
             video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
             video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
         else:
             video_outputs = None
+            num_frames = 0
         if pixel_values_images is not None:
             image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
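
The hard 8-frame check could be dropped because the reshape above already folds frames into the batch dimension, so the vision tower never depends on the frame count. A quick sketch with made-up shapes:

    import torch

    batch_size_vid, num_frames, channels, height, width = 2, 16, 3, 224, 224
    pixel_values_videos = torch.randn(batch_size_vid, num_frames, channels, height, width)

    # Frames simply become extra batch entries before hitting the video tower.
    pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
    print(pixel_values.shape)  # torch.Size([32, 3, 224, 224])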
@@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         else:
             image_outputs = None
-        return image_outputs, video_outputs
+        return image_outputs, video_outputs, num_frames

     @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         # 2. Merge text and images
         if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
-            image_outputs, video_outputs = self._get_vision_features(
+            image_outputs, video_outputs, num_frames = self._get_vision_features(
                 pixel_values_images=pixel_values_images,
                 pixel_values_videos=pixel_values_videos,
                 vision_feature_layer=vision_feature_layer,
@@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
                 input_ids,
                 attention_mask,
                 labels,
-                num_frames=8,
+                num_frames=num_frames,
             )
         else:
             # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
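
The three hunks above are plumbing: the frame count measured at encode time is returned, unpacked, and forwarded to the merge step in place of the hard-coded 8. A toy, runnable model of that flow (all names hypothetical):

    def get_vision_features(pixel_values_videos):
        if pixel_values_videos is not None:
            num_frames = len(pixel_values_videos)
            return f"video_feats[{num_frames} frames]", num_frames
        return None, 0  # mirrors the new `num_frames = 0` fallback

    def merge_with_visual_features(features, num_frames):
        return f"merged {features} using num_frames={num_frames}"

    features, num_frames = get_vision_features(pixel_values_videos=[None] * 16)
    print(merge_with_visual_features(features, num_frames=num_frames))
    # merged video_feats[16 frames] using num_frames=16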
...
@@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
+        # let's expand it to 16 frames, to check the model can handle any number of frames
+        video_file = video_file.repeat(2, 0)
         inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)

         # Make sure that `generate` works
...
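
The test doubles the demo clip with numpy's `repeat`; a quick sketch of that call (shapes assumed):

    import numpy as np

    clip = np.zeros((8, 224, 224, 3), dtype=np.uint8)
    # clip.repeat(2, 0) == np.repeat(clip, 2, axis=0): each frame appears twice,
    # turning the original 8-frame clip into a 16-frame one.
    doubled = clip.repeat(2, 0)
    print(doubled.shape)  # (16, 224, 224, 3)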