Commit 23db187d (unverified) in chenpangpang/transformers
Authored Mar 14, 2024 by Joao Gante; committed via GitHub on Mar 14, 2024

Generate: handle `cache_position` update in `generate` (#29467)
Parent: 7b87ecb0
Showing 5 changed files with 155 additions and 78 deletions (+155 / -78):
src/transformers/cache_utils.py                        +24  -14
src/transformers/generation/utils.py                   +63  -16
src/transformers/models/gemma/modeling_gemma.py        +32  -23
src/transformers/models/idefics/modeling_idefics.py     +4   -2
src/transformers/models/llama/modeling_llama.py        +32  -23
src/transformers/cache_utils.py
...
...
@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
import torch

from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@dataclass
...
...
@@ -57,6 +61,17 @@ class Cache:
            return max_length - new_seq_length
        return previous_seq_length

+   @property
+   def seen_tokens(self):
+       logger.warning_once(
+           "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+           "model input instead."
+       )
+       if hasattr(self, "_seen_tokens"):
+           return self._seen_tokens
+       else:
+           return None


class DynamicCache(Cache):
    """
...
...
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
-       self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+       self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
...
...
@@ -121,7 +136,7 @@ class DynamicCache(Cache):
        """
        # Update the number of seen tokens
        if layer_idx == 0:
-           self.seen_tokens += key_states.shape[-2]
+           self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
...
...
@@ -191,7 +206,7 @@ class SinkCache(Cache):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
-       self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+       self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
...
...
@@ -272,7 +287,7 @@ class SinkCache(Cache):
        # Update the number of seen tokens
        if layer_idx == 0:
-           self.seen_tokens += key_states.shape[-2]
+           self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
...
...
@@ -398,16 +413,11 @@ class StaticCache(Cache):
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-       # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-       raise ValueError(
-           "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-       )
-
-   def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
-       # TODO: Fix once the stateful `int` bug in PyTorch is fixed.
-       raise ValueError(
-           "get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
-       )
+       # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+       # limit the check to the first batch member and head dimension.
+       # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+       # https://github.com/pytorch/pytorch/issues/120248 is fixed
+       return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
...
...
src/transformers/generation/utils.py
...
...
@@ -633,7 +633,6 @@ class GenerationMixin:
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
-       model_inputs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
...
...
@@ -663,7 +662,8 @@ class GenerationMixin:
                dim=-1,
            )

-       model_kwargs["cache_position"] = model_inputs.get("cache_position", None)
+       if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
+           model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1

        return model_kwargs
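A minimal sketch of the update rule this hunk adds (toy values): after prefill, `cache_position` enumerates every prompt position; each later call keeps only the last position and advances it by one, so the model always knows which cache slot the new token belongs to.

import torch

cache_position = torch.arange(6)          # prefill of a 6-token prompt: positions 0..5
for step in range(3):                     # three single-token decoding steps
    cache_position = cache_position[-1:] + 1
    print(step, cache_position)           # tensor([6]), tensor([7]), tensor([8])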
...
...
@@ -1931,10 +1931,15 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+       batch_size, cur_len = (
+           model_kwargs["attention_mask"].shape
+           if model_kwargs.get("attention_mask", None) is not None
+           else input_ids.shape
+       )
+       unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
-       batch_size = input_ids.shape[0]

        while True:
            if synced_gpus:
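A standalone sketch of the initialization this hunk adds (hypothetical batch): `attention_mask.shape` yields `(batch_size, cur_len)` even with padding, and the initial `cache_position` simply enumerates the prompt positions.

import torch

attention_mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])  # 2 sequences, length 4
batch_size, cur_len = attention_mask.shape
unfinished_sequences = torch.ones(batch_size, dtype=torch.long)
cache_position = torch.arange(cur_len)
print(batch_size, cur_len, cache_position)  # 2 4 tensor([0, 1, 2, 3])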
...
...
@@ -1975,7 +1980,6 @@ class GenerationMixin:
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                standardize_cache_format=True,
-               model_inputs=model_inputs,
            )

            if not sequential:
                # Expands model inputs top_k times, for batched forward passes (akin to beam search).
...
...
@@ -2170,7 +2174,9 @@ class GenerationMixin:
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished
...
...
@@ -2389,7 +2395,13 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+       batch_size, cur_len = (
+           model_kwargs["attention_mask"].shape
+           if model_kwargs.get("attention_mask", None) is not None
+           else input_ids.shape
+       )
+       unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only

        while True:
...
...
@@ -2459,7 +2471,6 @@ class GenerationMixin:
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
-               model_inputs=model_inputs,
            )

            # if eos_token was found in one sentence, set sentence to finished
...
...
@@ -2688,7 +2699,13 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+       batch_size, cur_len = (
+           model_kwargs["attention_mask"].shape
+           if model_kwargs.get("attention_mask", None) is not None
+           else input_ids.shape
+       )
+       unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only

        # auto-regressive generation
...
...
@@ -2758,7 +2775,9 @@ class GenerationMixin:
            if streamer is not None:
                streamer.put(next_tokens.cpu())

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished
...
...
@@ -3003,6 +3022,7 @@ class GenerationMixin:
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
...
...
@@ -3156,7 +3176,9 @@ class GenerationMixin:
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
...
...
@@ -3397,6 +3419,7 @@ class GenerationMixin:
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
...
...
@@ -3510,7 +3533,9 @@ class GenerationMixin:
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
...
...
@@ -3747,6 +3772,7 @@ class GenerationMixin:
        device = input_ids.device

        batch_beam_size, cur_len = input_ids.shape
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if return_dict_in_generate and output_scores:
            beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
...
...
@@ -3916,7 +3942,9 @@ class GenerationMixin:
            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
...
...
@@ -4155,6 +4183,7 @@ class GenerationMixin:
        num_beams = constrained_beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
...
...
@@ -4275,7 +4304,9 @@ class GenerationMixin:
            input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
...
...
@@ -4511,7 +4542,13 @@ class GenerationMixin:
        )

        # keep track of which sequences are already finished
-       unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+       batch_size, cur_len = batch_size, cur_len = (
+           model_kwargs["attention_mask"].shape
+           if model_kwargs.get("attention_mask", None) is not None
+           else input_ids.shape
+       )
+       unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+       model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)

        # other auxiliary variables
        max_len = stopping_criteria[0].max_length
...
...
@@ -4555,6 +4592,14 @@ class GenerationMixin:
                candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
            )
            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
+           if "cache_position" in candidate_kwargs:
+               candidate_kwargs["cache_position"] = torch.cat(
+                   (
+                       candidate_kwargs["cache_position"],
+                       torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
+                   ),
+                   dim=0,
+               )

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
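A minimal sketch of the extension above (toy lengths): the assistant proposes `candidate_length` draft tokens that are scored in one forward pass, so the running `cache_position` is concatenated with the slots those candidates will occupy.

import torch

cur_len, candidate_length = 7, 3
cache_position = torch.arange(cur_len)              # positions already in the cache
cache_position = torch.cat(
    (cache_position, torch.arange(cur_len, cur_len + candidate_length, dtype=torch.long)),
    dim=0,
)
print(cache_position)  # tensor([0, 1, ..., 9])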
...
...
@@ -4673,7 +4718,9 @@ class GenerationMixin:
            )

            model_kwargs = self._update_model_kwargs_for_generation(
-               outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs
+               outputs,
+               model_kwargs,
+               is_encoder_decoder=self.config.is_encoder_decoder,
            )

            # if eos_token was found in one sentence, set sentence to finished
...
...
src/transformers/models/gemma/modeling_gemma.py
...
...
@@ -256,7 +256,7 @@ class GemmaAttention(nn.Module):
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
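A toy sketch (not the `StaticCache` implementation) of why the cache update needs `cache_position`: with a pre-allocated cache, the freshly computed key/value states must be written into explicit slots rather than appended.

import torch

max_len, heads, head_dim = 16, 2, 4
key_cache = torch.zeros(1, heads, max_len, head_dim)    # pre-allocated cache
new_keys = torch.randn(1, heads, 3, head_dim)           # 3 freshly computed positions
cache_position = torch.arange(5, 8)                     # they belong in slots 5..7
key_cache[:, :, cache_position] = new_keys
print(key_cache[0, 0].any(dim=-1).nonzero().flatten())  # tensor([5, 6, 7])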
...
...
@@ -343,7 +343,7 @@ class GemmaFlashAttention2(GemmaAttention):
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
...
...
@@ -542,7 +542,7 @@ class GemmaSdpaAttention(GemmaAttention):
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
...
...
@@ -791,6 +791,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+       cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+           Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+           this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+           the complete sequence length.
"""
...
...
@@ -1128,14 +1132,26 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
    )

    def prepare_inputs_for_generation(
-       self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+       self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
-               cache_length = past_key_values.get_seq_length()
-               past_length = past_key_values.seen_tokens
-               max_cache_length = past_key_values.get_max_length()
+               past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+               max_cache_length = (
+                   torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                   if past_key_values.get_max_length() is not None
+                   else None
+               )
+               cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
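A minimal sketch of the `past_length` logic above (hypothetical sizes): once `generate` passes `cache_position`, its first entry already says how many tokens precede the current input, so the deprecated `seen_tokens` counter is no longer consulted.

import torch

cache_position = torch.arange(4, 9)  # 4 tokens already cached, 5 tokens in this forward pass
past_length = cache_position[0] if cache_position is not None else 0
print(int(past_length))              # 4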
...
...
@@ -1168,22 +1184,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

-       if self.generation_config.cache_implementation == "static":
-           # generation with static cache
-           cache_position = kwargs.get("cache_position", None)
-           if cache_position is None:
-               past_length = 0
-           else:
-               past_length = cache_position[-1] + 1
-           input_ids = input_ids[:, past_length:]
-           position_ids = position_ids[:, past_length:]
-
-       # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
-       # same goes for position ids. Could also help with continued generation.
-       input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-       cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-       position_ids = position_ids.contiguous() if position_ids is not None else None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
...
...
@@ -1193,6 +1193,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

+       input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+       if cache_position is None:
+           cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+       else:
+           cache_position = cache_position[-input_length:]
+
+       if has_static_cache:
+           past_key_values = None
+
        model_inputs.update(
            {
                "position_ids": position_ids,
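A toy sketch of the slicing above: only the positions of the tokens actually fed to the model in this step are kept, which is one element per step during decoding and the full range during prefill.

import torch

cache_position = torch.arange(9)       # hypothetical running positions 0..8
input_length = 1                       # a single new token during decoding
print(cache_position[-input_length:])  # tensor([8])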
...
...
src/transformers/models/idefics/modeling_idefics.py
...
...
@@ -1557,10 +1557,12 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
-       model_inputs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
-           outputs, model_kwargs, is_encoder_decoder, standardize_cache_format, model_inputs
+           outputs,
+           model_kwargs,
+           is_encoder_decoder,
+           standardize_cache_format,
        )

        if "image_attention_mask" in model_kwargs:
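A minimal sketch of the override pattern in this file (illustrative names, not the transformers signatures): the subclass defers to the parent update, which no longer takes `model_inputs`, and then adjusts its own extra kwargs.

import torch


class Base:
    def update_kwargs(self, outputs, kwargs):
        kwargs["step"] = kwargs.get("step", 0) + 1
        return kwargs


class WithImageMask(Base):
    def update_kwargs(self, outputs, kwargs):
        kwargs = super().update_kwargs(outputs, kwargs)  # parent handles the shared bookkeeping
        if "image_attention_mask" in kwargs:
            # keep only the mask column for the newest token
            kwargs["image_attention_mask"] = kwargs["image_attention_mask"][:, -1:]
        return kwargs


print(WithImageMask().update_kwargs(None, {"image_attention_mask": torch.ones(1, 3)}))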
...
...
src/transformers/models/llama/modeling_llama.py
...
...
@@ -361,7 +361,7 @@ class LlamaAttention(nn.Module):
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
...
...
@@ -451,7 +451,7 @@ class LlamaFlashAttention2(LlamaAttention):
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
...
...
@@ -650,7 +650,7 @@ class LlamaSdpaAttention(LlamaAttention):
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
-           # sin and cos are specific to RoPE models; position_ids needed for the static cache
+           # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
...
...
@@ -903,6 +903,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+       cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+           Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+           this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+           the complete sequence length.
"""
...
...
@@ -1240,14 +1244,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
    )

    def prepare_inputs_for_generation(
-       self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+       self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
-               cache_length = past_key_values.get_seq_length()
-               past_length = past_key_values.seen_tokens
-               max_cache_length = past_key_values.get_max_length()
+               past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+               max_cache_length = (
+                   torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                   if past_key_values.get_max_length() is not None
+                   else None
+               )
+               cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None
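A small sketch of the clamping above (hypothetical sizes): with a bounded cache such as a static or sink cache, the usable cache length can never exceed the cache's capacity.

import torch

past_length = torch.tensor(12)
max_cache_length = torch.tensor(8)  # capacity of a hypothetical bounded cache
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
print(int(cache_length))  # 8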
...
...
@@ -1280,22 +1296,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

-       if self.generation_config.cache_implementation == "static":
-           # generation with static cache
-           cache_position = kwargs.get("cache_position", None)
-           if cache_position is None:
-               past_length = 0
-           else:
-               past_length = cache_position[-1] + 1
-           input_ids = input_ids[:, past_length:]
-           position_ids = position_ids[:, past_length:]
-
-       # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
-       # same goes for position ids. Could also help with continued generation.
-       input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-       cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-       position_ids = position_ids.contiguous() if position_ids is not None else None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
...
...
@@ -1305,6 +1305,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

+       input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+       if cache_position is None:
+           cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+       else:
+           cache_position = cache_position[-input_length:]
+
+       if has_static_cache:
+           past_key_values = None
+
        model_inputs.update(
            {
                "position_ids": position_ids,
...
...