Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
23db187d
Unverified
Commit
23db187d
authored
Mar 14, 2024
by
Joao Gante
Committed by
GitHub
Mar 14, 2024
Browse files
Generate: handle `cache_position` update in `generate` (#29467)
parent
7b87ecb0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
155 additions
and
78 deletions
+155
-78
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+24
-14
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+63
-16
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+32
-23
src/transformers/models/idefics/modeling_idefics.py
src/transformers/models/idefics/modeling_idefics.py
+4
-2
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+32
-23
No files found.
src/transformers/cache_utils.py
View file @
23db187d
...
...
@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
from
.configuration_utils
import
PretrainedConfig
from
.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
@
dataclass
...
...
@@ -57,6 +61,17 @@ class Cache:
return
max_length
-
new_seq_length
return
previous_seq_length
@
property
def
seen_tokens
(
self
):
logger
.
warning_once
(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if
hasattr
(
self
,
"_seen_tokens"
):
return
self
.
_seen_tokens
else
:
return
None
class
DynamicCache
(
Cache
):
"""
...
...
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def
__init__
(
self
)
->
None
:
self
.
key_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
value_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
self
.
_
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
def
__getitem__
(
self
,
layer_idx
:
int
)
->
List
[
Tuple
[
torch
.
Tensor
]]:
"""
...
...
@@ -121,7 +136,7 @@ class DynamicCache(Cache):
"""
# Update the number of seen tokens
if
layer_idx
==
0
:
self
.
seen_tokens
+=
key_states
.
shape
[
-
2
]
self
.
_
seen_tokens
+=
key_states
.
shape
[
-
2
]
# Update the cache
if
len
(
self
.
key_cache
)
<=
layer_idx
:
...
...
@@ -191,7 +206,7 @@ class SinkCache(Cache):
self
.
window_length
=
window_length
self
.
num_sink_tokens
=
num_sink_tokens
self
.
cos_sin_cache
=
{}
self
.
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
self
.
_
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
@
staticmethod
def
_rotate_half
(
x
):
...
...
@@ -272,7 +287,7 @@ class SinkCache(Cache):
# Update the number of seen tokens
if
layer_idx
==
0
:
self
.
seen_tokens
+=
key_states
.
shape
[
-
2
]
self
.
_
seen_tokens
+=
key_states
.
shape
[
-
2
]
# [bsz, num_heads, seq_len, head_dim]
if
len
(
self
.
key_cache
)
<=
layer_idx
:
...
...
@@ -398,16 +413,11 @@ class StaticCache(Cache):
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise
ValueError
(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
def
get_usable_length
(
self
,
new_sequence_length
=
None
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise
ValueError
(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return
(
self
.
key_cache
[
0
,
0
].
any
(
dim
=-
1
)).
sum
()
def
get_max_length
(
self
)
->
Optional
[
int
]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
...
...
src/transformers/generation/utils.py
View file @
23db187d
...
...
@@ -633,7 +633,6 @@ class GenerationMixin:
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
model_inputs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
Dict
[
str
,
Any
]:
# update past_key_values
model_kwargs
[
"past_key_values"
]
=
self
.
_extract_past_from_model_output
(
...
...
@@ -663,7 +662,8 @@ class GenerationMixin:
dim
=-
1
,
)
model_kwargs
[
"cache_position"
]
=
model_inputs
.
get
(
"cache_position"
,
None
)
if
"cache_position"
in
model_kwargs
and
model_kwargs
[
"cache_position"
]
is
not
None
:
model_kwargs
[
"cache_position"
]
=
model_kwargs
[
"cache_position"
][
-
1
:]
+
1
return
model_kwargs
...
...
@@ -1931,10 +1931,15 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
batch_size
=
input_ids
.
shape
[
0
]
while
True
:
if
synced_gpus
:
...
...
@@ -1975,7 +1980,6 @@ class GenerationMixin:
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
standardize_cache_format
=
True
,
model_inputs
=
model_inputs
,
)
if
not
sequential
:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
...
...
@@ -2170,7 +2174,9 @@ class GenerationMixin:
if
streamer
is
not
None
:
streamer
.
put
(
next_tokens
.
cpu
())
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
# if eos_token was found in one sentence, set sentence to finished
...
...
@@ -2389,7 +2395,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
while
True
:
...
...
@@ -2459,7 +2471,6 @@ class GenerationMixin:
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
,
)
# if eos_token was found in one sentence, set sentence to finished
...
...
@@ -2688,7 +2699,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
# auto-regressive generation
...
...
@@ -2758,7 +2775,9 @@ class GenerationMixin:
if
streamer
is
not
None
:
streamer
.
put
(
next_tokens
.
cpu
())
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
# if eos_token was found in one sentence, set sentence to finished
...
...
@@ -3003,6 +3022,7 @@ class GenerationMixin:
num_beams
=
beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
...
...
@@ -3156,7 +3176,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
...
@@ -3397,6 +3419,7 @@ class GenerationMixin:
num_beams
=
beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
# init attention / hidden states / scores tuples
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
...
...
@@ -3510,7 +3533,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
...
@@ -3747,6 +3772,7 @@ class GenerationMixin:
device
=
input_ids
.
device
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
return_dict_in_generate
and
output_scores
:
beam_indices
=
[
tuple
(()
for
_
in
range
(
num_sub_beams
*
batch_size
))
for
_
in
range
(
num_beam_groups
)]
...
...
@@ -3916,7 +3942,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
,
current_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
...
@@ -4155,6 +4183,7 @@ class GenerationMixin:
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
...
...
@@ -4275,7 +4304,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
...
@@ -4511,7 +4542,13 @@ class GenerationMixin:
)
# keep track of which sequences are already finished
unfinished_sequences
=
input_ids
.
new
(
input_ids
.
shape
[
0
]).
fill_
(
1
)
batch_size
,
cur_len
=
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
# other auxiliary variables
max_len
=
stopping_criteria
[
0
].
max_length
...
...
@@ -4555,6 +4592,14 @@ class GenerationMixin:
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
],
self
.
config
.
is_encoder_decoder
)
candidate_kwargs
=
_prepare_token_type_ids
(
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
])
if
"cache_position"
in
candidate_kwargs
:
candidate_kwargs
[
"cache_position"
]
=
torch
.
cat
(
(
candidate_kwargs
[
"cache_position"
],
torch
.
arange
(
cur_len
,
cur_len
+
candidate_length
,
device
=
input_ids
.
device
,
dtype
=
torch
.
long
),
),
dim
=
0
,
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
candidate_input_ids
,
**
candidate_kwargs
)
...
...
@@ -4673,7 +4718,9 @@ class GenerationMixin:
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
# if eos_token was found in one sentence, set sentence to finished
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
23db187d
...
...
@@ -256,7 +256,7 @@ class GemmaAttention(nn.Module):
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
None
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -343,7 +343,7 @@ class GemmaFlashAttention2(GemmaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -542,7 +542,7 @@ class GemmaSdpaAttention(GemmaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -791,6 +791,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
...
...
@@ -1128,14 +1132,26 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
**
kwargs
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
cache_position
=
None
,
**
kwargs
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache
=
False
if
past_key_values
is
None
:
past_key_values
=
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
)
has_static_cache
=
past_key_values
is
not
None
past_length
=
0
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
max_cache_length
=
(
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
if
past_key_values
.
get_max_length
()
is
not
None
else
None
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
...
...
@@ -1168,22 +1184,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
if
self
.
generation_config
.
cache_implementation
==
"static"
:
# generation with static cache
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
past_length
=
0
else
:
past_length
=
cache_position
[
-
1
]
+
1
input_ids
=
input_ids
[:,
past_length
:]
position_ids
=
position_ids
[:,
past_length
:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
contiguous
()
if
position_ids
is
not
None
else
None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
...
...
@@ -1193,6 +1193,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
if
cache_position
is
None
:
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
else
:
cache_position
=
cache_position
[
-
input_length
:]
if
has_static_cache
:
past_key_values
=
None
model_inputs
.
update
(
{
"position_ids"
:
position_ids
,
...
...
src/transformers/models/idefics/modeling_idefics.py
View file @
23db187d
...
...
@@ -1557,10 +1557,12 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
model_inputs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
Dict
[
str
,
Any
]:
model_kwargs
=
super
().
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
,
standardize_cache_format
,
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
,
standardize_cache_format
,
)
if
"image_attention_mask"
in
model_kwargs
:
...
...
src/transformers/models/llama/modeling_llama.py
View file @
23db187d
...
...
@@ -361,7 +361,7 @@ class LlamaAttention(nn.Module):
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -451,7 +451,7 @@ class LlamaFlashAttention2(LlamaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -650,7 +650,7 @@ class LlamaSdpaAttention(LlamaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
...
@@ -903,6 +903,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
...
...
@@ -1240,14 +1244,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
)
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
**
kwargs
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
cache_position
=
None
,
**
kwargs
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache
=
False
if
past_key_values
is
None
:
past_key_values
=
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
)
has_static_cache
=
past_key_values
is
not
None
past_length
=
0
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
max_cache_length
=
(
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
if
past_key_values
.
get_max_length
()
is
not
None
else
None
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
...
...
@@ -1280,22 +1296,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
if
self
.
generation_config
.
cache_implementation
==
"static"
:
# generation with static cache
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
past_length
=
0
else
:
past_length
=
cache_position
[
-
1
]
+
1
input_ids
=
input_ids
[:,
past_length
:]
position_ids
=
position_ids
[:,
past_length
:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
contiguous
()
if
position_ids
is
not
None
else
None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
...
...
@@ -1305,6 +1305,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
if
cache_position
is
None
:
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
else
:
cache_position
=
cache_position
[
-
input_length
:]
if
has_static_cache
:
past_key_values
=
None
model_inputs
.
update
(
{
"position_ids"
:
position_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment