Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aa7ab98e
Unverified
Commit
aa7ab98e
authored
Dec 08, 2023
by
Arthur
Committed by
GitHub
Dec 08, 2023
Browse files
fix llava (#27909)
* fix llava * nits * attention_mask was forgotten * nice * :) * fixup
parent
e0b617d1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
6 deletions
+50
-6
src/transformers/models/llava/modeling_llava.py
src/transformers/models/llava/modeling_llava.py
+50
-6
No files found.
src/transformers/models/llava/modeling_llava.py
View file @
aa7ab98e
...
@@ -22,6 +22,7 @@ from torch import nn
...
@@ -22,6 +22,7 @@ from torch import nn
from
...
import
PreTrainedModel
from
...
import
PreTrainedModel
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...cache_utils
import
Cache
from
...modeling_outputs
import
ModelOutput
from
...modeling_outputs
import
ModelOutput
from
...utils
import
(
from
...utils
import
(
add_start_docstrings
,
add_start_docstrings
,
...
@@ -472,14 +473,57 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
...
@@ -472,14 +473,57 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
)
)
def
prepare_inputs_for_generation
(
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
inputs_embeds
=
None
,
pixel_values
=
None
,
**
kwargs
self
,
input_ids
,
past_key_values
=
None
,
inputs_embeds
=
None
,
pixel_values
=
None
,
attention_mask
=
None
,
**
kwargs
):
):
# Call `prepare_inputs_for_generation` from the LM
if
past_key_values
is
not
None
:
model_input
=
self
.
language_model
.
prepare_inputs_for_generation
(
if
isinstance
(
past_key_values
,
Cache
):
input_ids
,
past_key_values
,
inputs_embeds
=
inputs_embeds
,
**
kwargs
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
# input)
if
attention_mask
is
not
None
and
attention_mask
.
shape
[
1
]
>
input_ids
.
shape
[
1
]:
input_ids
=
input_ids
[:,
-
(
attention_mask
.
shape
[
1
]
-
past_length
)
:]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif
past_length
<
input_ids
.
shape
[
1
]:
input_ids
=
input_ids
[:,
past_length
:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif
self
.
config
.
image_token_index
in
input_ids
:
input_ids
=
input_ids
[:,
input_ids
.
shape
[
1
]
-
1
:]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if
cache_length
<
past_length
and
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
-
(
cache_length
+
input_ids
.
shape
[
1
])
:]
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
if
attention_mask
is
not
None
and
position_ids
is
None
:
# create position_ids on the fly for batch generation
position_ids
=
attention_mask
.
long
().
cumsum
(
-
1
)
-
1
position_ids
.
masked_fill_
(
attention_mask
==
0
,
1
)
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
else
:
model_inputs
=
{
"input_ids"
:
input_ids
}
model_inputs
.
update
(
{
"position_ids"
:
position_ids
,
"past_key_values"
:
past_key_values
,
"use_cache"
:
kwargs
.
get
(
"use_cache"
),
"attention_mask"
:
attention_mask
,
"pixel_values"
:
pixel_values
,
}
)
)
model_input
.
update
({
"pixel_values"
:
pixel_values
})
return
model_inputs
return
model_input
def
_reorder_cache
(
self
,
*
args
,
**
kwargs
):
def
_reorder_cache
(
self
,
*
args
,
**
kwargs
):
return
self
.
language_model
.
_reorder_cache
(
*
args
,
**
kwargs
)
return
self
.
language_model
.
_reorder_cache
(
*
args
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment