chenpangpang / transformers / commit ce0bbd51
"tests/vscode:/vscode.git/clone" did not exist on "ce2fef2ad278cd72748dfe5b049b4d58569e3d9a"
Unverified commit ce0bbd51, authored Dec 08, 2023 by Joao Gante, committed by GitHub on Dec 08, 2023.
Generate: SinkCache can handle iterative prompts (#27907)
Parent: 94c76538

Showing 6 changed files with 116 additions and 34 deletions:
    src/transformers/cache_utils.py                            +27   -3
    src/transformers/models/llama/modeling_llama.py            +12   -7
    src/transformers/models/mistral/modeling_mistral.py        +12   -9
    src/transformers/models/persimmon/modeling_persimmon.py    +11   -6
    src/transformers/models/phi/modeling_phi.py                +12   -9
    tests/test_cache_utils.py                                  +42   -0
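In user terms, the commit lets a single SinkCache be reused across successive model.generate() calls while new user turns are appended in between. A minimal sketch of that pattern, distilled from the new test in tests/test_cache_utils.py (the checkpoint and generation settings are taken from that test; the user messages are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
)

# A single SinkCache instance is kept alive across every generate() call.
cache = SinkCache(window_length=256, num_sink_tokens=4)
input_ids = torch.tensor([], device=model.device, dtype=torch.int)

for user_message in ["Hi!", "Tell me more.", "Thanks!"]:  # placeholder turns
    chat = [{"role": "user", "content": user_message}]
    new_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
    # Append the new turn to everything generated so far, then generate against the shared cache.
    input_ids = torch.cat((input_ids, new_ids), dim=1)
    gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True)
    input_ids = gen_out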
src/transformers/cache_utils.py
@@ -38,6 +38,21 @@ class Cache:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+        # length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
 
 class DynamicCache(Cache):
     """
@@ -120,6 +135,10 @@ class DynamicCache(Cache):
             return 0
         return self.key_cache[layer_idx].shape[-2]
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ class SinkCache(Cache):
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
-        cache_length = self.key_cache[layer_idx].shape[-2]
-        return min(cache_length, self.window_length - 1)
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
 
     def update(
         self,
@@ -267,7 +289,9 @@ class SinkCache(Cache):
 
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
                         keys_to_keep[..., :partial_rotation_size],
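Taken together, the cache_utils.py changes mean SinkCache now reports its true cached length via get_seq_length, exposes its budget via get_max_length, and relies on get_usable_length to keep the per-layer kv_seq_len within the window when a multi-token prompt arrives. A small numeric sketch of how the attention layers in the model files below consume this (the window and token counts are arbitrary example values):

# Example values, not taken from any specific model run.
window_length = 256  # SinkCache budget, returned by get_max_length()
cached = 256         # tokens already held by the cache, returned by get_seq_length()
new_tokens = 10      # length of the new prompt chunk being fed in

# get_usable_length(): the cache would overflow, so only part of it counts as usable.
usable = window_length - new_tokens if cached + new_tokens > window_length else cached

# The attention layers now do `kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)`,
# so rotary embeddings and masks are sized for what the cache will actually hold after eviction.
kv_seq_len = new_tokens + usable
assert kv_seq_len == window_length  # 10 + 246 == 256, never exceeding the window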
src/transformers/models/llama/modeling_llama.py
@@ -398,7 +398,7 @@ class LlamaAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -503,7 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -910,7 +910,7 @@ class LlamaModel(LlamaPreTrainedModel):
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1127,8 +1127,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
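The new mask-cropping branch above (repeated in the mistral, persimmon, and phi files below) matters exactly in the iterative-prompt case: a full size-limited cache plus a fresh multi-token prompt would otherwise yield an attention mask longer than the cache can ever hold. A rough worked example of the condition, with made-up sizes:

import torch

# Illustrative sizes only: a full SinkCache plus a 30-token follow-up prompt.
max_cache_length = 256  # past_key_values.get_max_length()
cache_length = 256      # past_key_values.get_seq_length()
input_ids = torch.ones(1, 30, dtype=torch.long)
attention_mask = torch.ones(1, cache_length + input_ids.shape[1])

# Same condition as in prepare_inputs_for_generation: we are about to exceed the cache budget...
if (
    max_cache_length is not None
    and attention_mask is not None
    and cache_length + input_ids.shape[1] > max_cache_length
):
    # ...so keep only the most recent max_cache_length mask positions.
    attention_mask = attention_mask[:, -max_cache_length:]

assert attention_mask.shape[1] == 256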
src/transformers/models/mistral/modeling_mistral.py
@@ -268,7 +268,7 @@ class MistralAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -363,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -850,15 +850,13 @@ class MistralModel(MistralPreTrainedModel):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ class MistralForCausalLM(MistralPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
src/transformers/models/persimmon/modeling_persimmon.py
@@ -295,7 +295,7 @@ class PersimmonAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -612,7 +612,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if position_ids is None:
@@ -831,8 +831,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -846,10 +848,13 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
src/transformers/models/phi/modeling_phi.py
@@ -334,7 +334,7 @@ class PhiAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -444,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -855,15 +855,13 @@ class PhiModel(PhiPreTrainedModel):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1085,8 +1083,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1100,10 +1100,13 @@ class PhiForCausalLM(PhiPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
tests/test_cache_utils.py
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces a coherent english
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
+            "was visiting the historic district of Honolulu. Here,"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))