chenpangpang / transformers · Commits

Commit aa7ab98e (unverified)
Authored Dec 08, 2023 by Arthur; committed by GitHub on Dec 08, 2023

fix llava (#27909)

* fix llava
* nits
* attention_mask was forgotten
* nice
* :)
* fixup
Parent: e0b617d1

Changes: 1 changed file with 50 additions and 6 deletions.

src/transformers/models/llava/modeling_llava.py  +50 −6
```diff
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -22,6 +22,7 @@ from torch import nn
 
 from ... import PreTrainedModel
 from ...activations import ACT2FN
+from ...cache_utils import Cache
 from ...modeling_outputs import ModelOutput
 from ...utils import (
     add_start_docstrings,
@@ -472,14 +473,57 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
     )
     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, **kwargs
+        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
     ):
-        # Call `prepare_inputs_for_generation` from the LM
-        model_input = self.language_model.prepare_inputs_for_generation(
-            input_ids, past_key_values, inputs_embeds=inputs_embeds, **kwargs
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting
+            # where some of the inputs are exclusively passed as part of the cache (e.g. when passing
+            # inputs_embeds as input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can
+            # discard input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+            elif self.config.image_token_index in input_ids:
+                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard
+            # the older attention values, as their corresponding values are not part of the input.
+            if cache_length < past_length and attention_mask is not None:
+                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "pixel_values": pixel_values,
+            }
         )
-        model_input.update({"pixel_values": pixel_values})
-        return model_input
+        return model_inputs
 
     def _reorder_cache(self, *args, **kwargs):
         return self.language_model._reorder_cache(*args, **kwargs)
```
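To see what the new token-trimming rules do in isolation, here is a minimal, self-contained sketch (not part of the commit) that replays the same pruning on dummy tensors, covering only the legacy tuple-of-tuples cache path; the `trim_inputs` helper and all shapes are invented for illustration:

```python
import torch

def trim_inputs(input_ids, attention_mask, past_key_values):
    # Mirrors the pruning rules added in this commit (legacy-cache path only).
    if past_key_values is not None:
        # A legacy cache entry has shape (batch, num_heads, seq_len, head_dim).
        cache_length = past_length = past_key_values[0][0].shape[2]

        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Rule 1: part of the prompt lives only in the cache (e.g. inputs_embeds).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
        elif past_length < input_ids.shape[1]:
            # Rule 2: input_ids still holds the cached prefix; drop it.
            input_ids = input_ids[:, past_length:]

        if cache_length < past_length and attention_mask is not None:
            # Size-limited cache: drop mask entries for evicted positions.
            attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]):]
    return input_ids, attention_mask

# After a 5-token prefill, the sequence has grown to 6 tokens: only the
# newly appended token should be fed to the model on the next step.
batch, heads, head_dim = 1, 4, 8
past = ((torch.zeros(batch, heads, 5, head_dim),) * 2,)  # one layer: (key, value)
input_ids = torch.arange(6).unsqueeze(0)                 # the full 6-token sequence
attention_mask = torch.ones(batch, 6, dtype=torch.long)

input_ids, attention_mask = trim_inputs(input_ids, attention_mask, past)
print(input_ids)             # tensor([[5]]) -- only the unprocessed token remains
print(attention_mask.shape)  # torch.Size([1, 6]) -- mask still spans cache + new token
```

Only the last token is fed forward while the mask keeps covering the cached prefix, which is exactly what an incremental decoding step needs.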
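The `attention_mask` that "was forgotten" also feeds the on-the-fly `position_ids` computation, which is what keeps left-padded batch generation aligned. A tiny runnable sketch of just that computation, with invented inputs:

```python
import torch

attention_mask = torch.tensor([
    [1, 1, 1, 1],   # full-length prompt
    [0, 0, 1, 1],   # left-padded prompt
])
# Same two lines as in the diff above.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
```

Padded slots get a dummy position of 1, but they are masked out of attention anyway; what matters is that the real tokens of the padded row start at position 0.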
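For context, this method is never called directly: `generate()` invokes `prepare_inputs_for_generation` at every decoding step. A hedged end-to-end sketch of that path, assuming the `llava-hf/llava-1.5-7b-hf` checkpoint (any Llava checkpoint would do; running this downloads the weights and needs `accelerate` for `device_map="auto"`):

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # assumed checkpoint, not named in the commit
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat is shown in this image? ASSISTANT:"

inputs = processor(text=prompt, images=image, return_tensors="pt").to(
    model.device, torch.float16
)
# generate() calls prepare_inputs_for_generation on every step; with this
# commit, attention_mask and pixel_values are threaded through correctly.
output = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(output[0], skip_special_tokens=True))
```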