Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d64e4da7
Unverified
Commit
d64e4da7
authored
Jun 04, 2024
by
Raushan Turganbay
Committed by
GitHub
Jun 04, 2024
Browse files
Video-LLaVa: handle any number of frames (#31221)
video-llava can handle more frames
parent
36ade4a3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
src/transformers/models/video_llava/modeling_video_llava.py
src/transformers/models/video_llava/modeling_video_llava.py
+5
-6
tests/models/video_llava/test_modeling_video_llava.py
tests/models/video_llava/test_modeling_video_llava.py
+3
-0
No files found.
src/transformers/models/video_llava/modeling_video_llava.py
View file @
d64e4da7
...
@@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
...
@@ -287,7 +287,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
num_images
,
num_image_patches
,
embed_dim
=
visual_features
.
shape
num_images
,
num_image_patches
,
embed_dim
=
visual_features
.
shape
batch_size
,
sequence_length
=
input_ids
.
shape
batch_size
,
sequence_length
=
input_ids
.
shape
left_padding
=
not
torch
.
sum
(
input_ids
[:,
-
1
]
==
torch
.
tensor
(
self
.
pad_token_id
))
left_padding
=
not
torch
.
sum
(
input_ids
[:,
-
1
]
==
torch
.
tensor
(
self
.
pad_token_id
))
special_vision_token
=
self
.
config
.
video_token_index
if
num_frames
==
8
else
self
.
config
.
image_token_index
special_vision_token
=
self
.
config
.
video_token_index
if
num_frames
>
1
else
self
.
config
.
image_token_index
# 1. Create a mask to know where special image tokens are
# 1. Create a mask to know where special image tokens are
special_image_token_mask
=
input_ids
==
special_vision_token
special_image_token_mask
=
input_ids
==
special_vision_token
...
@@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
...
@@ -375,14 +375,13 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
# videos do not need to select features and it's always "full" (as it is done in the orig implementation)
# videos do not need to select features and it's always "full" (as it is done in the orig implementation)
if
pixel_values_videos
is
not
None
:
if
pixel_values_videos
is
not
None
:
batch_size_vid
,
num_frames
,
channels
,
height
,
width
=
pixel_values_videos
.
shape
batch_size_vid
,
num_frames
,
channels
,
height
,
width
=
pixel_values_videos
.
shape
if
num_frames
!=
8
:
raise
ValueError
(
f
"Video pixel values should have exactly `8` frames but foung `
{
num_frames
}
`"
)
pixel_values
=
pixel_values_videos
.
reshape
(
batch_size_vid
*
num_frames
,
channels
,
height
,
width
)
pixel_values
=
pixel_values_videos
.
reshape
(
batch_size_vid
*
num_frames
,
channels
,
height
,
width
)
video_outputs
=
self
.
video_tower
(
pixel_values
,
output_hidden_states
=
True
)
video_outputs
=
self
.
video_tower
(
pixel_values
,
output_hidden_states
=
True
)
video_outputs
=
video_outputs
.
hidden_states
[
vision_feature_layer
].
squeeze
(
1
)
video_outputs
=
video_outputs
.
hidden_states
[
vision_feature_layer
].
squeeze
(
1
)
else
:
else
:
video_outputs
=
None
video_outputs
=
None
num_frames
=
0
if
pixel_values_images
is
not
None
:
if
pixel_values_images
is
not
None
:
image_outputs
=
self
.
image_tower
(
pixel_values_images
,
output_hidden_states
=
True
)
image_outputs
=
self
.
image_tower
(
pixel_values_images
,
output_hidden_states
=
True
)
...
@@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
...
@@ -397,7 +396,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
else
:
else
:
image_outputs
=
None
image_outputs
=
None
return
image_outputs
,
video_outputs
return
image_outputs
,
video_outputs
,
num_frames
@
add_start_docstrings_to_model_forward
(
VIDEO_LLAVA_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
VIDEO_LLAVA_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
VideoLlavaCausalLMOutputWithPast
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
VideoLlavaCausalLMOutputWithPast
,
config_class
=
_CONFIG_FOR_DOC
)
...
@@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
...
@@ -513,7 +512,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
# 2. Merge text and images
# 2. Merge text and images
if
(
pixel_values_images
is
not
None
or
pixel_values_videos
is
not
None
)
and
input_ids
.
shape
[
1
]
!=
1
:
if
(
pixel_values_images
is
not
None
or
pixel_values_videos
is
not
None
)
and
input_ids
.
shape
[
1
]
!=
1
:
image_outputs
,
video_outputs
=
self
.
_get_vision_features
(
image_outputs
,
video_outputs
,
num_frames
=
self
.
_get_vision_features
(
pixel_values_images
=
pixel_values_images
,
pixel_values_images
=
pixel_values_images
,
pixel_values_videos
=
pixel_values_videos
,
pixel_values_videos
=
pixel_values_videos
,
vision_feature_layer
=
vision_feature_layer
,
vision_feature_layer
=
vision_feature_layer
,
...
@@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
...
@@ -546,7 +545,7 @@ class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel):
input_ids
,
input_ids
,
attention_mask
,
attention_mask
,
labels
,
labels
,
num_frames
=
8
,
num_frames
=
num_frames
,
)
)
else
:
else
:
# In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
# In case input_ids.shape[1] == 1 & past_key_values != None, we are in the case of
...
...
tests/models/video_llava/test_modeling_video_llava.py
View file @
d64e4da7
...
@@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
...
@@ -487,6 +487,9 @@ class VideoLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
repo_id
=
"raushan-testing-hf/videos-test"
,
filename
=
"video_demo.npy"
,
repo_type
=
"dataset"
repo_id
=
"raushan-testing-hf/videos-test"
,
filename
=
"video_demo.npy"
,
repo_type
=
"dataset"
)
)
video_file
=
np
.
load
(
video_file
)
video_file
=
np
.
load
(
video_file
)
# let's expand it for 16 frames, to check model can handle any number of frames
video_file
=
video_file
.
repeat
(
2
,
0
)
inputs
=
self
.
processor
(
prompt
,
videos
=
video_file
,
return_tensors
=
"pt"
).
to
(
torch_device
,
torch
.
float16
)
inputs
=
self
.
processor
(
prompt
,
videos
=
video_file
,
return_tensors
=
"pt"
).
to
(
torch_device
,
torch
.
float16
)
# Make sure that `generate` works
# Make sure that `generate` works
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment