chenpangpang/transformers, commit d64e4da7 (unverified)
Authored Jun 04, 2024 by Raushan Turganbay; committed by GitHub on Jun 04, 2024
Video-LLaVa: handle any number of frames (#31221)
video-llava can handle more frames
Parent: 36ade4a3
Showing 2 changed files with 8 additions and 6 deletions:

    src/transformers/models/video_llava/modeling_video_llava.py   +5 -6
    tests/models/video_llava/test_modeling_video_llava.py          +3 -0
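What the change buys in practice: a clip no longer has to be sampled down to exactly 8 frames before going through the processor. A minimal usage sketch under stated assumptions; the checkpoint id, clip shape, and prompt below are illustrative and not taken from this commit:

```python
import numpy as np
from transformers import VideoLlavaProcessor, VideoLlavaForConditionalGeneration

# Illustrative checkpoint; any Video-LLaVa checkpoint should behave the same way.
model_id = "LanguageBind/Video-LLaVA-7B-hf"
processor = VideoLlavaProcessor.from_pretrained(model_id)
model = VideoLlavaForConditionalGeneration.from_pretrained(model_id)

# A dummy 16-frame clip (frames, height, width, channels). Before this commit,
# any frame count other than 8 raised a ValueError inside the vision-feature step.
clip = np.zeros((16, 224, 224, 3), dtype=np.uint8)
prompt = "USER: <video>\nWhat is happening in this video? ASSISTANT:"

inputs = processor(text=prompt, videos=clip, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```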
src/transformers/models/video_llava/modeling_video_llava.py
@@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         num_images, num_image_patches, embed_dim = visual_features.shape
         batch_size, sequence_length = input_ids.shape
         left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
-        special_vision_token = self.config.video_token_index if num_frames == 8 else self.config.image_token_index
+        special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index
         # 1. Create a mask to know where special image tokens are
         special_image_token_mask = input_ids == special_vision_token
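The hunk above replaces a hard-coded frame count with a general rule: any multi-frame input selects the video placeholder token, a single frame selects the image one. A toy version of that dispatch; the token ids below are made up for illustration (the real ones come from the model config):

```python
def pick_vision_token(num_frames: int, video_token_index: int = 32001, image_token_index: int = 32000) -> int:
    # After the change: one frame is an image, anything longer is a video.
    return video_token_index if num_frames > 1 else image_token_index

assert pick_vision_token(1) == 32000   # image path
assert pick_vision_token(8) == 32001   # video (previously the only accepted length)
assert pick_vision_token(16) == 32001  # arbitrary frame counts now take the video path
```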
@@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         # videos do not need to select features and it's always "full" (as it is done in the orig implementation)
         if pixel_values_videos is not None:
             batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape
-            if num_frames != 8:
-                raise ValueError(f"Video pixel values should have exactly `8` frames but foung `{num_frames}`")
             pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
             video_outputs = self.video_tower(pixel_values, output_hidden_states=True)
             video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1)
         else:
             video_outputs = None
+            num_frames = 0

         if pixel_values_images is not None:
             image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True)
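The deleted guard was the only thing pinning the video tower to 8 frames: the reshape on the next line already folds any number of frames into the batch dimension, so each frame is encoded like a standalone image. A shape-only sketch with illustrative dimensions:

```python
import torch

batch_size_vid, num_frames, channels, height, width = 2, 16, 3, 224, 224
pixel_values_videos = torch.randn(batch_size_vid, num_frames, channels, height, width)

# Fold frames into the batch dimension, exactly as the hunk does.
pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width)
assert pixel_values.shape == (32, 3, 224, 224)  # every frame is now its own batch entry
```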
@@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
         else:
             image_outputs = None

-        return image_outputs, video_outputs
+        return image_outputs, video_outputs, num_frames

     @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
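With the return value extended, callers learn how many frames were actually encoded instead of assuming 8, and receive 0 when no video was passed, matching the new `num_frames = 0` branch. A self-contained toy stand-in for the changed function, simplified to shapes only (not the real implementation):

```python
import torch

def get_vision_features(pixel_values_videos):
    """Toy stand-in for _get_vision_features: returns (features, num_frames)."""
    if pixel_values_videos is None:
        return None, 0  # mirrors the new `num_frames = 0` branch
    b, f, c, h, w = pixel_values_videos.shape
    # Stand-in for the video tower: one scalar feature per flattened frame.
    features = pixel_values_videos.reshape(b * f, c, h, w).mean(dim=(1, 2, 3))
    return features, f

_, num_frames = get_vision_features(torch.randn(1, 12, 3, 8, 8))
assert num_frames == 12  # the caller now threads the real count into the merge step
```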
@@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
             # 2. Merge text and images
             if (pixel_values_images is not None or pixel_values_videos is not None) and input_ids.shape[1] != 1:
-                image_outputs, video_outputs = self._get_vision_features(
+                image_outputs, video_outputs, num_frames = self._get_vision_features(
                     pixel_values_images=pixel_values_images,
                     pixel_values_videos=pixel_values_videos,
                     vision_feature_layer=vision_feature_layer,
@@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
                     input_ids,
                     attention_mask,
                     labels,
-                    num_frames=8,
+                    num_frames=num_frames,
                 )
             else:
                 # In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
tests/models/video_llava/test_modeling_video_llava.py
@@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
             repo_id="raushan-testing-hf/videos-test", filename="video_demo.npy", repo_type="dataset"
         )
         video_file = np.load(video_file)
+        # let's expand it for 16 frames, to check model can handle any number of frames
+        video_file = video_file.repeat(2, 0)
+
         inputs = self.processor(prompt, videos=video_file, return_tensors="pt").to(torch_device, torch.float16)

         # Make sure that `generate` works
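`ndarray.repeat(2, 0)` duplicates every frame along axis 0, so the 8-frame test asset becomes a 16-frame clip without any new data, which is enough to exercise the relaxed frame-count handling. The shapes below are illustrative:

```python
import numpy as np

clip = np.zeros((8, 224, 224, 3), dtype=np.uint8)  # 8 frames, like the test asset
longer = clip.repeat(2, 0)                          # each frame repeated twice along axis 0
assert longer.shape == (16, 224, 224, 3)
```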