Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
23db187d
Unverified
Commit
23db187d
authored
Mar 14, 2024
by
Joao Gante
Committed by
GitHub
Mar 14, 2024
Browse files
Generate: handle `cache_position` update in `generate` (#29467)
parent
7b87ecb0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
155 additions
and
78 deletions
+155
-78
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+24
-14
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+63
-16
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+32
-23
src/transformers/models/idefics/modeling_idefics.py
src/transformers/models/idefics/modeling_idefics.py
+4
-2
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+32
-23
No files found.
src/transformers/cache_utils.py
View file @
23db187d
...
@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
...
@@ -4,6 +4,10 @@ from typing import Any, Dict, List, Optional, Tuple
import
torch
import
torch
from
.configuration_utils
import
PretrainedConfig
from
.configuration_utils
import
PretrainedConfig
from
.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
@
dataclass
@
dataclass
...
@@ -57,6 +61,17 @@ class Cache:
...
@@ -57,6 +61,17 @@ class Cache:
return
max_length
-
new_seq_length
return
max_length
-
new_seq_length
return
previous_seq_length
return
previous_seq_length
@
property
def
seen_tokens
(
self
):
logger
.
warning_once
(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if
hasattr
(
self
,
"_seen_tokens"
):
return
self
.
_seen_tokens
else
:
return
None
class
DynamicCache
(
Cache
):
class
DynamicCache
(
Cache
):
"""
"""
...
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
...
@@ -69,7 +84,7 @@ class DynamicCache(Cache):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
self
.
key_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
key_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
value_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
value_cache
:
List
[
torch
.
Tensor
]
=
[]
self
.
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
self
.
_
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
def
__getitem__
(
self
,
layer_idx
:
int
)
->
List
[
Tuple
[
torch
.
Tensor
]]:
def
__getitem__
(
self
,
layer_idx
:
int
)
->
List
[
Tuple
[
torch
.
Tensor
]]:
"""
"""
...
@@ -121,7 +136,7 @@ class DynamicCache(Cache):
...
@@ -121,7 +136,7 @@ class DynamicCache(Cache):
"""
"""
# Update the number of seen tokens
# Update the number of seen tokens
if
layer_idx
==
0
:
if
layer_idx
==
0
:
self
.
seen_tokens
+=
key_states
.
shape
[
-
2
]
self
.
_
seen_tokens
+=
key_states
.
shape
[
-
2
]
# Update the cache
# Update the cache
if
len
(
self
.
key_cache
)
<=
layer_idx
:
if
len
(
self
.
key_cache
)
<=
layer_idx
:
...
@@ -191,7 +206,7 @@ class SinkCache(Cache):
...
@@ -191,7 +206,7 @@ class SinkCache(Cache):
self
.
window_length
=
window_length
self
.
window_length
=
window_length
self
.
num_sink_tokens
=
num_sink_tokens
self
.
num_sink_tokens
=
num_sink_tokens
self
.
cos_sin_cache
=
{}
self
.
cos_sin_cache
=
{}
self
.
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
self
.
_
seen_tokens
=
0
# Used in `generate` to keep tally of how many tokens the cache has seen
@
staticmethod
@
staticmethod
def
_rotate_half
(
x
):
def
_rotate_half
(
x
):
...
@@ -272,7 +287,7 @@ class SinkCache(Cache):
...
@@ -272,7 +287,7 @@ class SinkCache(Cache):
# Update the number of seen tokens
# Update the number of seen tokens
if
layer_idx
==
0
:
if
layer_idx
==
0
:
self
.
seen_tokens
+=
key_states
.
shape
[
-
2
]
self
.
_
seen_tokens
+=
key_states
.
shape
[
-
2
]
# [bsz, num_heads, seq_len, head_dim]
# [bsz, num_heads, seq_len, head_dim]
if
len
(
self
.
key_cache
)
<=
layer_idx
:
if
len
(
self
.
key_cache
)
<=
layer_idx
:
...
@@ -398,16 +413,11 @@ class StaticCache(Cache):
...
@@ -398,16 +413,11 @@ class StaticCache(Cache):
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
def
get_seq_length
(
self
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
raise
ValueError
(
# limit the check to the first batch member and head dimension.
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
)
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return
(
self
.
key_cache
[
0
,
0
].
any
(
dim
=-
1
)).
sum
()
def
get_usable_length
(
self
,
new_sequence_length
=
None
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
# TODO: Fix once the stateful `int` bug in PyTorch is fixed.
raise
ValueError
(
"get_seq_length is not implemented for StaticCache. Please refer to https://github.com/huggingface/transformers/pull/29114."
)
def
get_max_length
(
self
)
->
Optional
[
int
]:
def
get_max_length
(
self
)
->
Optional
[
int
]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
...
...
src/transformers/generation/utils.py
View file @
23db187d
...
@@ -633,7 +633,6 @@ class GenerationMixin:
...
@@ -633,7 +633,6 @@ class GenerationMixin:
model_kwargs
:
Dict
[
str
,
Any
],
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
,
is_encoder_decoder
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
model_inputs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
# update past_key_values
# update past_key_values
model_kwargs
[
"past_key_values"
]
=
self
.
_extract_past_from_model_output
(
model_kwargs
[
"past_key_values"
]
=
self
.
_extract_past_from_model_output
(
...
@@ -663,7 +662,8 @@ class GenerationMixin:
...
@@ -663,7 +662,8 @@ class GenerationMixin:
dim
=-
1
,
dim
=-
1
,
)
)
model_kwargs
[
"cache_position"
]
=
model_inputs
.
get
(
"cache_position"
,
None
)
if
"cache_position"
in
model_kwargs
and
model_kwargs
[
"cache_position"
]
is
not
None
:
model_kwargs
[
"cache_position"
]
=
model_kwargs
[
"cache_position"
][
-
1
:]
+
1
return
model_kwargs
return
model_kwargs
...
@@ -1931,10 +1931,15 @@ class GenerationMixin:
...
@@ -1931,10 +1931,15 @@ class GenerationMixin:
)
)
# keep track of which sequences are already finished
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
this_peer_finished
=
False
# used by synced_gpus only
batch_size
=
input_ids
.
shape
[
0
]
while
True
:
while
True
:
if
synced_gpus
:
if
synced_gpus
:
...
@@ -1975,7 +1980,6 @@ class GenerationMixin:
...
@@ -1975,7 +1980,6 @@ class GenerationMixin:
model_kwargs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
standardize_cache_format
=
True
,
standardize_cache_format
=
True
,
model_inputs
=
model_inputs
,
)
)
if
not
sequential
:
if
not
sequential
:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
...
@@ -2170,7 +2174,9 @@ class GenerationMixin:
...
@@ -2170,7 +2174,9 @@ class GenerationMixin:
if
streamer
is
not
None
:
if
streamer
is
not
None
:
streamer
.
put
(
next_tokens
.
cpu
())
streamer
.
put
(
next_tokens
.
cpu
())
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
# if eos_token was found in one sentence, set sentence to finished
# if eos_token was found in one sentence, set sentence to finished
...
@@ -2389,7 +2395,13 @@ class GenerationMixin:
...
@@ -2389,7 +2395,13 @@ class GenerationMixin:
)
)
# keep track of which sequences are already finished
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
this_peer_finished
=
False
# used by synced_gpus only
while
True
:
while
True
:
...
@@ -2459,7 +2471,6 @@ class GenerationMixin:
...
@@ -2459,7 +2471,6 @@ class GenerationMixin:
outputs
,
outputs
,
model_kwargs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
,
)
)
# if eos_token was found in one sentence, set sentence to finished
# if eos_token was found in one sentence, set sentence to finished
...
@@ -2688,7 +2699,13 @@ class GenerationMixin:
...
@@ -2688,7 +2699,13 @@ class GenerationMixin:
)
)
# keep track of which sequences are already finished
# keep track of which sequences are already finished
unfinished_sequences
=
torch
.
ones
(
input_ids
.
shape
[
0
],
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
this_peer_finished
=
False
# used by synced_gpus only
this_peer_finished
=
False
# used by synced_gpus only
# auto-regressive generation
# auto-regressive generation
...
@@ -2758,7 +2775,9 @@ class GenerationMixin:
...
@@ -2758,7 +2775,9 @@ class GenerationMixin:
if
streamer
is
not
None
:
if
streamer
is
not
None
:
streamer
.
put
(
next_tokens
.
cpu
())
streamer
.
put
(
next_tokens
.
cpu
())
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
# if eos_token was found in one sentence, set sentence to finished
# if eos_token was found in one sentence, set sentence to finished
...
@@ -3003,6 +3022,7 @@ class GenerationMixin:
...
@@ -3003,6 +3022,7 @@ class GenerationMixin:
num_beams
=
beam_scorer
.
num_beams
num_beams
=
beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
num_beams
*
batch_size
!=
batch_beam_size
:
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
raise
ValueError
(
...
@@ -3156,7 +3176,9 @@ class GenerationMixin:
...
@@ -3156,7 +3176,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
@@ -3397,6 +3419,7 @@ class GenerationMixin:
...
@@ -3397,6 +3419,7 @@ class GenerationMixin:
num_beams
=
beam_scorer
.
num_beams
num_beams
=
beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
# init attention / hidden states / scores tuples
# init attention / hidden states / scores tuples
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
scores
=
()
if
(
return_dict_in_generate
and
output_scores
)
else
None
...
@@ -3510,7 +3533,9 @@ class GenerationMixin:
...
@@ -3510,7 +3533,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
@@ -3747,6 +3772,7 @@ class GenerationMixin:
...
@@ -3747,6 +3772,7 @@ class GenerationMixin:
device
=
input_ids
.
device
device
=
input_ids
.
device
batch_beam_size
,
cur_len
=
input_ids
.
shape
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
return_dict_in_generate
and
output_scores
:
if
return_dict_in_generate
and
output_scores
:
beam_indices
=
[
tuple
(()
for
_
in
range
(
num_sub_beams
*
batch_size
))
for
_
in
range
(
num_beam_groups
)]
beam_indices
=
[
tuple
(()
for
_
in
range
(
num_sub_beams
*
batch_size
))
for
_
in
range
(
num_beam_groups
)]
...
@@ -3916,7 +3942,9 @@ class GenerationMixin:
...
@@ -3916,7 +3942,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
,
current_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
input_ids
=
torch
.
cat
([
input_ids
,
current_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
@@ -4155,6 +4183,7 @@ class GenerationMixin:
...
@@ -4155,6 +4183,7 @@ class GenerationMixin:
num_beams
=
constrained_beam_scorer
.
num_beams
num_beams
=
constrained_beam_scorer
.
num_beams
batch_beam_size
,
cur_len
=
input_ids
.
shape
batch_beam_size
,
cur_len
=
input_ids
.
shape
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
if
num_beams
*
batch_size
!=
batch_beam_size
:
if
num_beams
*
batch_size
!=
batch_beam_size
:
raise
ValueError
(
raise
ValueError
(
...
@@ -4275,7 +4304,9 @@ class GenerationMixin:
...
@@ -4275,7 +4304,9 @@ class GenerationMixin:
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
input_ids
=
torch
.
cat
([
input_ids
[
beam_idx
,
:],
beam_next_tokens
.
unsqueeze
(
-
1
)],
dim
=-
1
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
if
model_kwargs
.
get
(
"past_key_values"
,
None
)
is
not
None
:
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
model_kwargs
[
"past_key_values"
]
=
self
.
_temporary_reorder_cache
(
...
@@ -4511,7 +4542,13 @@ class GenerationMixin:
...
@@ -4511,7 +4542,13 @@ class GenerationMixin:
)
)
# keep track of which sequences are already finished
# keep track of which sequences are already finished
unfinished_sequences
=
input_ids
.
new
(
input_ids
.
shape
[
0
]).
fill_
(
1
)
batch_size
,
cur_len
=
batch_size
,
cur_len
=
(
model_kwargs
[
"attention_mask"
].
shape
if
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
else
input_ids
.
shape
)
unfinished_sequences
=
torch
.
ones
(
batch_size
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
model_kwargs
[
"cache_position"
]
=
torch
.
arange
(
cur_len
,
device
=
input_ids
.
device
)
# other auxiliary variables
# other auxiliary variables
max_len
=
stopping_criteria
[
0
].
max_length
max_len
=
stopping_criteria
[
0
].
max_length
...
@@ -4555,6 +4592,14 @@ class GenerationMixin:
...
@@ -4555,6 +4592,14 @@ class GenerationMixin:
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
],
self
.
config
.
is_encoder_decoder
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
],
self
.
config
.
is_encoder_decoder
)
)
candidate_kwargs
=
_prepare_token_type_ids
(
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
])
candidate_kwargs
=
_prepare_token_type_ids
(
candidate_kwargs
,
candidate_input_ids
.
shape
[
1
])
if
"cache_position"
in
candidate_kwargs
:
candidate_kwargs
[
"cache_position"
]
=
torch
.
cat
(
(
candidate_kwargs
[
"cache_position"
],
torch
.
arange
(
cur_len
,
cur_len
+
candidate_length
,
device
=
input_ids
.
device
,
dtype
=
torch
.
long
),
),
dim
=
0
,
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
candidate_input_ids
,
**
candidate_kwargs
)
model_inputs
=
self
.
prepare_inputs_for_generation
(
candidate_input_ids
,
**
candidate_kwargs
)
...
@@ -4673,7 +4718,9 @@ class GenerationMixin:
...
@@ -4673,7 +4718,9 @@ class GenerationMixin:
)
)
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
model_kwargs
=
self
.
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
model_inputs
=
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
=
self
.
config
.
is_encoder_decoder
,
)
)
# if eos_token was found in one sentence, set sentence to finished
# if eos_token was found in one sentence, set sentence to finished
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
23db187d
...
@@ -256,7 +256,7 @@ class GemmaAttention(nn.Module):
...
@@ -256,7 +256,7 @@ class GemmaAttention(nn.Module):
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
None
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
None
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -343,7 +343,7 @@ class GemmaFlashAttention2(GemmaAttention):
...
@@ -343,7 +343,7 @@ class GemmaFlashAttention2(GemmaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -542,7 +542,7 @@ class GemmaSdpaAttention(GemmaAttention):
...
@@ -542,7 +542,7 @@ class GemmaSdpaAttention(GemmaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -791,6 +791,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
...
@@ -791,6 +791,10 @@ GEMMA_INPUTS_DOCSTRING = r"""
more detail.
more detail.
return_dict (`bool`, *optional*):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
"""
...
@@ -1128,14 +1132,26 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
...
@@ -1128,14 +1132,26 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
)
)
def
prepare_inputs_for_generation
(
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
**
kwargs
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
cache_position
=
None
,
**
kwargs
):
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache
=
False
if
past_key_values
is
None
:
past_key_values
=
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
)
has_static_cache
=
past_key_values
is
not
None
past_length
=
0
past_length
=
0
if
past_key_values
is
not
None
:
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
(
max_cache_length
=
past_key_values
.
get_max_length
()
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
if
past_key_values
.
get_max_length
()
is
not
None
else
None
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
max_cache_length
=
None
...
@@ -1168,22 +1184,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
...
@@ -1168,22 +1184,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
if
past_key_values
:
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
if
self
.
generation_config
.
cache_implementation
==
"static"
:
# generation with static cache
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
past_length
=
0
else
:
past_length
=
cache_position
[
-
1
]
+
1
input_ids
=
input_ids
[:,
past_length
:]
position_ids
=
position_ids
[:,
past_length
:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
contiguous
()
if
position_ids
is
not
None
else
None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
...
@@ -1193,6 +1193,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
...
@@ -1193,6 +1193,15 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
# TODO: use `next_tokens` directly instead.
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
if
cache_position
is
None
:
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
else
:
cache_position
=
cache_position
[
-
input_length
:]
if
has_static_cache
:
past_key_values
=
None
model_inputs
.
update
(
model_inputs
.
update
(
{
{
"position_ids"
:
position_ids
,
"position_ids"
:
position_ids
,
...
...
src/transformers/models/idefics/modeling_idefics.py
View file @
23db187d
...
@@ -1557,10 +1557,12 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
...
@@ -1557,10 +1557,12 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
model_kwargs
:
Dict
[
str
,
Any
],
model_kwargs
:
Dict
[
str
,
Any
],
is_encoder_decoder
:
bool
=
False
,
is_encoder_decoder
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
standardize_cache_format
:
bool
=
False
,
model_inputs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
Dict
[
str
,
Any
]:
)
->
Dict
[
str
,
Any
]:
model_kwargs
=
super
().
_update_model_kwargs_for_generation
(
model_kwargs
=
super
().
_update_model_kwargs_for_generation
(
outputs
,
model_kwargs
,
is_encoder_decoder
,
standardize_cache_format
,
model_inputs
outputs
,
model_kwargs
,
is_encoder_decoder
,
standardize_cache_format
,
)
)
if
"image_attention_mask"
in
model_kwargs
:
if
"image_attention_mask"
in
model_kwargs
:
...
...
src/transformers/models/llama/modeling_llama.py
View file @
23db187d
...
@@ -361,7 +361,7 @@ class LlamaAttention(nn.Module):
...
@@ -361,7 +361,7 @@ class LlamaAttention(nn.Module):
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -451,7 +451,7 @@ class LlamaFlashAttention2(LlamaAttention):
...
@@ -451,7 +451,7 @@ class LlamaFlashAttention2(LlamaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -650,7 +650,7 @@ class LlamaSdpaAttention(LlamaAttention):
...
@@ -650,7 +650,7 @@ class LlamaSdpaAttention(LlamaAttention):
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
past_key_value
=
getattr
(
self
,
"past_key_value"
,
past_key_value
)
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
# sin and cos are specific to RoPE models; position
_ids
needed for the static cache
# sin and cos are specific to RoPE models;
cache_
position needed for the static cache
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
cache_kwargs
=
{
"sin"
:
sin
,
"cos"
:
cos
,
"cache_position"
:
cache_position
}
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
key_states
,
value_states
=
past_key_value
.
update
(
key_states
,
value_states
,
self
.
layer_idx
,
cache_kwargs
)
...
@@ -903,6 +903,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
...
@@ -903,6 +903,10 @@ LLAMA_INPUTS_DOCSTRING = r"""
more detail.
more detail.
return_dict (`bool`, *optional*):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
"""
...
@@ -1240,14 +1244,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
...
@@ -1240,14 +1244,26 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
)
)
def
prepare_inputs_for_generation
(
def
prepare_inputs_for_generation
(
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
**
kwargs
self
,
input_ids
,
past_key_values
=
None
,
attention_mask
=
None
,
inputs_embeds
=
None
,
cache_position
=
None
,
**
kwargs
):
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache
=
False
if
past_key_values
is
None
:
past_key_values
=
getattr
(
self
.
model
.
layers
[
0
].
self_attn
,
"past_key_value"
,
None
)
has_static_cache
=
past_key_values
is
not
None
past_length
=
0
past_length
=
0
if
past_key_values
is
not
None
:
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
cache_position
[
0
]
if
cache_position
is
not
None
else
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
(
max_cache_length
=
past_key_values
.
get_max_length
()
torch
.
tensor
(
past_key_values
.
get_max_length
(),
device
=
input_ids
.
device
)
if
past_key_values
.
get_max_length
()
is
not
None
else
None
)
cache_length
=
past_length
if
max_cache_length
is
None
else
torch
.
min
(
max_cache_length
,
past_length
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
max_cache_length
=
None
...
@@ -1280,22 +1296,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
...
@@ -1280,22 +1296,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if
past_key_values
:
if
past_key_values
:
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
position_ids
=
position_ids
[:,
-
input_ids
.
shape
[
1
]
:]
if
self
.
generation_config
.
cache_implementation
==
"static"
:
# generation with static cache
cache_position
=
kwargs
.
get
(
"cache_position"
,
None
)
if
cache_position
is
None
:
past_length
=
0
else
:
past_length
=
cache_position
[
-
1
]
+
1
input_ids
=
input_ids
[:,
past_length
:]
position_ids
=
position_ids
[:,
past_length
:]
# TODO @gante we should only keep a `cache_position` in generate, and do +=1.
# same goes for position ids. Could also help with continued generation.
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
contiguous
()
if
position_ids
is
not
None
else
None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
if
inputs_embeds
is
not
None
and
past_key_values
is
None
:
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
model_inputs
=
{
"inputs_embeds"
:
inputs_embeds
}
...
@@ -1305,6 +1305,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
...
@@ -1305,6 +1305,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
# TODO: use `next_tokens` directly instead.
# TODO: use `next_tokens` directly instead.
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
model_inputs
=
{
"input_ids"
:
input_ids
.
contiguous
()}
input_length
=
position_ids
.
shape
[
-
1
]
if
position_ids
is
not
None
else
input_ids
.
shape
[
-
1
]
if
cache_position
is
None
:
cache_position
=
torch
.
arange
(
past_length
,
past_length
+
input_length
,
device
=
input_ids
.
device
)
else
:
cache_position
=
cache_position
[
-
input_length
:]
if
has_static_cache
:
past_key_values
=
None
model_inputs
.
update
(
model_inputs
.
update
(
{
{
"position_ids"
:
position_ids
,
"position_ids"
:
position_ids
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment