Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ce0bbd51
Unverified
Commit
ce0bbd51
authored
Dec 08, 2023
by
Joao Gante
Committed by
GitHub
Dec 08, 2023
Browse files
Generate: SinkCache can handle iterative prompts (#27907)
parent
94c76538
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
116 additions
and
34 deletions
+116
-34
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+27
-3
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+12
-7
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+12
-9
src/transformers/models/persimmon/modeling_persimmon.py
src/transformers/models/persimmon/modeling_persimmon.py
+11
-6
src/transformers/models/phi/modeling_phi.py
src/transformers/models/phi/modeling_phi.py
+12
-9
tests/test_cache_utils.py
tests/test_cache_utils.py
+42
-0
No files found.
src/transformers/cache_utils.py
View file @
ce0bbd51
...
@@ -38,6 +38,21 @@ class Cache:
...
@@ -38,6 +38,21 @@ class Cache:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
raise
NotImplementedError
(
"Make sure to implement `get_seq_length` in a subclass."
)
raise
NotImplementedError
(
"Make sure to implement `get_seq_length` in a subclass."
)
def
get_max_length
(
self
)
->
Optional
[
int
]:
"""Returns the maximum sequence length of the cached states, if there is any."""
raise
NotImplementedError
(
"Make sure to implement `get_max_length` in a subclass."
)
def
get_usable_length
(
self
,
new_seq_length
:
int
,
layer_idx
:
Optional
[
int
]
=
0
)
->
int
:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length
=
self
.
get_max_length
()
previous_seq_length
=
self
.
get_seq_length
(
layer_idx
)
if
max_length
is
not
None
and
previous_seq_length
+
new_seq_length
>
max_length
:
return
max_length
-
new_seq_length
return
previous_seq_length
class
DynamicCache
(
Cache
):
class
DynamicCache
(
Cache
):
"""
"""
...
@@ -120,6 +135,10 @@ class DynamicCache(Cache):
...
@@ -120,6 +135,10 @@ class DynamicCache(Cache):
return
0
return
0
return
self
.
key_cache
[
layer_idx
].
shape
[
-
2
]
return
self
.
key_cache
[
layer_idx
].
shape
[
-
2
]
def
get_max_length
(
self
)
->
Optional
[
int
]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return
None
def
reorder_cache
(
self
,
beam_idx
:
torch
.
LongTensor
):
def
reorder_cache
(
self
,
beam_idx
:
torch
.
LongTensor
):
"""Reorders the cache for beam search, given the selected beam indices."""
"""Reorders the cache for beam search, given the selected beam indices."""
for
layer_idx
in
range
(
len
(
self
.
key_cache
)):
for
layer_idx
in
range
(
len
(
self
.
key_cache
)):
...
@@ -209,8 +228,11 @@ class SinkCache(Cache):
...
@@ -209,8 +228,11 @@ class SinkCache(Cache):
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
if
len
(
self
.
key_cache
)
<=
layer_idx
:
if
len
(
self
.
key_cache
)
<=
layer_idx
:
return
0
return
0
cache_length
=
self
.
key_cache
[
layer_idx
].
shape
[
-
2
]
return
self
.
key_cache
[
layer_idx
].
shape
[
-
2
]
return
min
(
cache_length
,
self
.
window_length
-
1
)
def
get_max_length
(
self
)
->
Optional
[
int
]:
"""Returns the maximum sequence length of the cached states."""
return
self
.
window_length
def
update
(
def
update
(
self
,
self
,
...
@@ -267,7 +289,9 @@ class SinkCache(Cache):
...
@@ -267,7 +289,9 @@ class SinkCache(Cache):
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
# On RoPE models, we need to recompute the Key rotation as the tokens are shifted
if
using_rope
:
if
using_rope
:
rerotation_cos
,
rerotation_sin
=
self
.
_get_rerotation_cos_sin
(
key_states
,
cos
,
sin
)
rerotation_cos
,
rerotation_sin
=
self
.
_get_rerotation_cos_sin
(
key_states
,
cos
[:
self
.
window_length
],
sin
[:
self
.
window_length
]
)
if
partial_rotation_size
is
not
None
:
if
partial_rotation_size
is
not
None
:
keys_to_keep
,
keys_pass
=
(
keys_to_keep
,
keys_pass
=
(
keys_to_keep
[...,
:
partial_rotation_size
],
keys_to_keep
[...,
:
partial_rotation_size
],
...
...
src/transformers/models/llama/modeling_llama.py
View file @
ce0bbd51
...
@@ -398,7 +398,7 @@ class LlamaAttention(nn.Module):
...
@@ -398,7 +398,7 @@ class LlamaAttention(nn.Module):
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
...
@@ -503,7 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
...
@@ -503,7 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
kv_seq_len
=
key_states
.
shape
[
-
2
]
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
...
@@ -910,7 +910,7 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -910,7 +910,7 @@ class LlamaModel(LlamaPreTrainedModel):
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values_length
=
past_key_values
.
get_seq_length
(
)
past_key_values_length
=
past_key_values
.
get_
usable_length
(
seq_length
)
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
...
@@ -1127,8 +1127,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
...
@@ -1127,8 +1127,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
# Keep only the unprocessed tokens:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
...
@@ -1142,10 +1144,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
...
@@ -1142,10 +1144,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
input_ids
=
input_ids
[:,
past_length
:]
input_ids
=
input_ids
[:,
past_length
:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
# older attention values, as their corresponding values are not part of the input.
if
(
if
cache_length
<
past_length
and
attention_mask
is
not
None
:
max_cache_length
is
not
None
attention_mask
=
attention_mask
[:,
-
(
cache_length
+
input_ids
.
shape
[
1
])
:]
and
attention_mask
is
not
None
and
cache_length
+
input_ids
.
shape
[
1
]
>
max_cache_length
):
attention_mask
=
attention_mask
[:,
-
max_cache_length
:]
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
if
attention_mask
is
not
None
and
position_ids
is
None
:
if
attention_mask
is
not
None
and
position_ids
is
None
:
...
...
src/transformers/models/mistral/modeling_mistral.py
View file @
ce0bbd51
...
@@ -268,7 +268,7 @@ class MistralAttention(nn.Module):
...
@@ -268,7 +268,7 @@ class MistralAttention(nn.Module):
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
query_states
,
key_states
=
apply_rotary_pos_emb
(
query_states
,
key_states
,
cos
,
sin
,
position_ids
)
...
@@ -363,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
...
@@ -363,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
kv_seq_len
=
key_states
.
shape
[
-
2
]
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
# Because the input can be padded, the absolute sequence length depends on the max position id.
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len
=
max
(
kv_seq_len
,
position_ids
[:,
-
1
].
max
().
item
())
+
1
rotary_seq_len
=
max
(
kv_seq_len
,
position_ids
[:,
-
1
].
max
().
item
())
+
1
...
@@ -850,15 +850,13 @@ class MistralModel(MistralPreTrainedModel):
...
@@ -850,15 +850,13 @@ class MistralModel(MistralPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past
=
seq_length
past_key_values_length
=
0
past_key_values_length
=
0
if
use_cache
:
if
use_cache
:
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values_length
=
past_key_values
.
get_seq_length
()
past_key_values_length
=
past_key_values
.
get_usable_length
(
seq_length
)
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
...
@@ -1092,8 +1090,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
...
@@ -1092,8 +1090,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
# Keep only the unprocessed tokens:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
...
@@ -1107,10 +1107,13 @@ class MistralForCausalLM(MistralPreTrainedModel):
...
@@ -1107,10 +1107,13 @@ class MistralForCausalLM(MistralPreTrainedModel):
input_ids
=
input_ids
[:,
past_length
:]
input_ids
=
input_ids
[:,
past_length
:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
# older attention values, as their corresponding values are not part of the input.
if
(
if
cache_length
<
past_length
and
attention_mask
is
not
None
:
max_cache_length
is
not
None
attention_mask
=
attention_mask
[:,
-
(
cache_length
+
input_ids
.
shape
[
1
])
:]
and
attention_mask
is
not
None
and
cache_length
+
input_ids
.
shape
[
1
]
>
max_cache_length
):
attention_mask
=
attention_mask
[:,
-
max_cache_length
:]
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
if
attention_mask
is
not
None
and
position_ids
is
None
:
if
attention_mask
is
not
None
and
position_ids
is
None
:
...
...
src/transformers/models/persimmon/modeling_persimmon.py
View file @
ce0bbd51
...
@@ -295,7 +295,7 @@ class PersimmonAttention(nn.Module):
...
@@ -295,7 +295,7 @@ class PersimmonAttention(nn.Module):
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
# Partial rotary embedding
# Partial rotary embedding
...
@@ -612,7 +612,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
...
@@ -612,7 +612,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values_length
=
past_key_values
.
get_seq_length
(
)
past_key_values_length
=
past_key_values
.
get_
usable_length
(
seq_length
)
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
if
position_ids
is
None
:
if
position_ids
is
None
:
...
@@ -831,8 +831,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
...
@@ -831,8 +831,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
# Keep only the unprocessed tokens:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
...
@@ -846,10 +848,13 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
...
@@ -846,10 +848,13 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
input_ids
=
input_ids
[:,
past_length
:]
input_ids
=
input_ids
[:,
past_length
:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
# older attention values, as their corresponding values are not part of the input.
if
(
if
cache_length
<
past_length
and
attention_mask
is
not
None
:
max_cache_length
is
not
None
attention_mask
=
attention_mask
[:,
-
(
cache_length
+
input_ids
.
shape
[
1
])
:]
and
attention_mask
is
not
None
and
cache_length
+
input_ids
.
shape
[
1
]
>
max_cache_length
):
attention_mask
=
attention_mask
[:,
-
max_cache_length
:]
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
if
attention_mask
is
not
None
and
position_ids
is
None
:
if
attention_mask
is
not
None
and
position_ids
is
None
:
...
...
src/transformers/models/phi/modeling_phi.py
View file @
ce0bbd51
...
@@ -334,7 +334,7 @@ class PhiAttention(nn.Module):
...
@@ -334,7 +334,7 @@ class PhiAttention(nn.Module):
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
"with a layer index."
)
)
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
# Partial rotary embedding
# Partial rotary embedding
...
@@ -444,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
...
@@ -444,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
kv_seq_len
=
key_states
.
shape
[
-
2
]
kv_seq_len
=
key_states
.
shape
[
-
2
]
if
past_key_value
is
not
None
:
if
past_key_value
is
not
None
:
kv_seq_len
+=
past_key_value
.
get_
seq
_length
(
self
.
layer_idx
)
kv_seq_len
+=
past_key_value
.
get_
usable
_length
(
kv_seq_len
,
self
.
layer_idx
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
cos
,
sin
=
self
.
rotary_emb
(
value_states
,
seq_len
=
kv_seq_len
)
# Partial rotary embedding
# Partial rotary embedding
...
@@ -855,15 +855,13 @@ class PhiModel(PhiPreTrainedModel):
...
@@ -855,15 +855,13 @@ class PhiModel(PhiPreTrainedModel):
else
:
else
:
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
raise
ValueError
(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past
=
seq_length
past_key_values_length
=
0
past_key_values_length
=
0
if
use_cache
:
if
use_cache
:
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
use_legacy_cache
=
not
isinstance
(
past_key_values
,
Cache
)
if
use_legacy_cache
:
if
use_legacy_cache
:
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values_length
=
past_key_values
.
get_seq_length
()
past_key_values_length
=
past_key_values
.
get_usable_length
(
seq_length
)
seq_length_with_past
=
seq_length_with_past
+
past_key_values_length
if
position_ids
is
None
:
if
position_ids
is
None
:
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
device
=
input_ids
.
device
if
input_ids
is
not
None
else
inputs_embeds
.
device
...
@@ -1085,8 +1083,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
...
@@ -1085,8 +1083,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
if
isinstance
(
past_key_values
,
Cache
):
if
isinstance
(
past_key_values
,
Cache
):
cache_length
=
past_key_values
.
get_seq_length
()
cache_length
=
past_key_values
.
get_seq_length
()
past_length
=
past_key_values
.
seen_tokens
past_length
=
past_key_values
.
seen_tokens
max_cache_length
=
past_key_values
.
get_max_length
()
else
:
else
:
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
cache_length
=
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
max_cache_length
=
None
# Keep only the unprocessed tokens:
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
...
@@ -1100,10 +1100,13 @@ class PhiForCausalLM(PhiPreTrainedModel):
...
@@ -1100,10 +1100,13 @@ class PhiForCausalLM(PhiPreTrainedModel):
input_ids
=
input_ids
[:,
past_length
:]
input_ids
=
input_ids
[:,
past_length
:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
# older attention values, as their corresponding values are not part of the input.
if
(
if
cache_length
<
past_length
and
attention_mask
is
not
None
:
max_cache_length
is
not
None
attention_mask
=
attention_mask
[:,
-
(
cache_length
+
input_ids
.
shape
[
1
])
:]
and
attention_mask
is
not
None
and
cache_length
+
input_ids
.
shape
[
1
]
>
max_cache_length
):
attention_mask
=
attention_mask
[:,
-
max_cache_length
:]
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
position_ids
=
kwargs
.
get
(
"position_ids"
,
None
)
if
attention_mask
is
not
None
and
position_ids
is
None
:
if
attention_mask
is
not
None
and
position_ids
is
None
:
...
...
tests/test_cache_utils.py
View file @
ce0bbd51
...
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
...
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
gen_out
=
model
.
generate
(
**
inputs
,
do_sample
=
False
,
max_new_tokens
=
3000
,
past_key_values
=
cache
)
gen_out
=
model
.
generate
(
**
inputs
,
do_sample
=
False
,
max_new_tokens
=
3000
,
past_key_values
=
cache
)
decoded
=
tokenizer
.
batch_decode
(
gen_out
,
skip_special_tokens
=
True
)
decoded
=
tokenizer
.
batch_decode
(
gen_out
,
skip_special_tokens
=
True
)
self
.
assertTrue
(
decoded
[
0
].
endswith
(
"to perform a variety of tasks. The Transformer is a neural network"
))
self
.
assertTrue
(
decoded
[
0
].
endswith
(
"to perform a variety of tasks. The Transformer is a neural network"
))
def
test_sink_cache_iterative_prompts
(
self
):
"""Tests that SinkCache supports more than one new token at once, when shifting the cache"""
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"HuggingFaceH4/zephyr-7b-beta"
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
"HuggingFaceH4/zephyr-7b-beta"
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
)
prompt
=
(
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
"and must-see attractions."
)
# Prepare generation settings
cache
=
SinkCache
(
window_length
=
256
,
num_sink_tokens
=
4
)
input_ids
=
torch
.
tensor
([],
device
=
model
.
device
,
dtype
=
torch
.
int
)
for
_
in
range
(
3
):
# Tokenize the prompt with the correct chat template
chat
=
[{
"role"
:
"user"
,
"content"
:
prompt
}]
tokenized_chat
=
tokenizer
.
apply_chat_template
(
chat
,
return_tensors
=
"pt"
,
add_generation_prompt
=
True
).
to
(
model
.
device
)
input_ids
=
torch
.
cat
((
input_ids
,
tokenized_chat
),
dim
=
1
)
# Perform the generation
gen_out
=
model
.
generate
(
input_ids
,
do_sample
=
False
,
max_new_tokens
=
100
,
past_key_values
=
cache
,
use_cache
=
True
)
input_ids
=
gen_out
# We went well beyond the cache length
self
.
assertTrue
(
input_ids
.
shape
[
1
]
>
cache
.
get_max_length
()
*
1.5
)
# And it still produces a coherent english
decoded
=
tokenizer
.
batch_decode
(
input_ids
,
skip_special_tokens
=
True
)
last_output
=
(
"<|assistant|>
\n
As the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
"Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
"beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
"and must-see attractions that left me breathless.
\n\n
One of the most memorable experiences of my trip "
"was visiting the historic district of Honolulu. Here,"
)
self
.
assertTrue
(
decoded
[
0
].
endswith
(
last_output
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment