"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "08b41b413a4065ec03a57016b1e75bd44302ee8b"
Unverified Commit ce0bbd51 authored by Joao Gante, committed by GitHub

Generate: SinkCache can handle iterative prompts (#27907)

parent 94c76538
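For context, the flow this commit enables is reusing a single SinkCache across several generate() calls, where every round appends a whole multi-token prompt. A minimal sketch of that usage, mirroring the new test_sink_cache_iterative_prompts added at the bottom of this diff (the checkpoint, window size, and messages are illustrative, not prescribed by the commit):

# A hedged sketch, not part of the diff: iterative prompting with a reused SinkCache.
# Checkpoint, window size, and messages are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
)

cache = SinkCache(window_length=256, num_sink_tokens=4)  # one cache, reused across rounds
input_ids = torch.tensor([], device=model.device, dtype=torch.int)

for user_message in ["Hi!", "Tell me more about attention sinks.", "Summarize the above."]:
    chat = [{"role": "user", "content": user_message}]
    new_ids = tokenizer.apply_chat_template(
        chat, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)
    # Each round appends a whole multi-token prompt, not a single token -- the case
    # this commit makes SinkCache handle when the window has to shift.
    input_ids = torch.cat((input_ids, new_ids), dim=1)
    input_ids = model.generate(
        input_ids, do_sample=False, max_new_tokens=64, past_key_values=cache, use_cache=True
    )

print(tokenizer.batch_decode(input_ids, skip_special_tokens=True)[0])

The new test's docstring below states the underlying limitation this addresses: the shifting cache previously did not support more than one new token arriving at once.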
@@ -38,6 +38,21 @@ class Cache:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
+        # cache length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
 
 class DynamicCache(Cache):
     """
@@ -120,6 +135,10 @@ class DynamicCache(Cache):
             return 0
         return self.key_cache[layer_idx].shape[-2]
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ class SinkCache(Cache):
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
-        cache_length = self.key_cache[layer_idx].shape[-2]
-        return min(cache_length, self.window_length - 1)
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
 
     def update(
         self,
@@ -267,7 +289,9 @@ class SinkCache(Cache):
 
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
                         keys_to_keep[..., :partial_rotation_size],
...
@@ -398,7 +398,7 @@ class LlamaAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -503,7 +503,7 @@ class LlamaFlashAttention2(LlamaAttention):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -910,7 +910,7 @@ class LlamaModel(LlamaPreTrainedModel):
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
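A rough arithmetic sketch (illustrative numbers only, not library code) of what switching from get_seq_length() to get_usable_length(seq_length) changes for the past-length bookkeeping in the model forward:

# Toy numbers: a SinkCache with window_length=256 already holds 250 tokens and the new
# iterative prompt brings 20 tokens. Assumed values for illustration only.
import torch

window_length, cached_tokens, seq_length = 256, 250, 20

# Old bookkeeping: past length = raw cache length -> 250, so positions would run up to 269.
old_past_length = cached_tokens
# New bookkeeping: past length = usable cache length (see Cache.get_usable_length above) -> 236.
new_past_length = (
    window_length - seq_length
    if cached_tokens + seq_length > window_length
    else cached_tokens
)

position_ids = torch.arange(new_past_length, new_past_length + seq_length)
# The new tokens together with the usable cache fit exactly inside the window.
assert new_past_length + seq_length == window_length  # 236 + 20 == 256
assert position_ids[-1].item() == window_length - 1   # 255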
@@ -1127,8 +1127,10 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
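And a matching sketch, with the same made-up sizes, for the new attention-mask cropping rule in prepare_inputs_for_generation above:

# Same toy scenario: max_cache_length=256, 250 tokens already cached, 20 new prompt tokens.
import torch

max_cache_length, cache_length, new_tokens = 256, 250, 20
attention_mask = torch.ones(1, cache_length + new_tokens)     # 270 columns so far

# Mirrors the new condition: the cache is bounded and we are about to exceed the bound,
# so older mask columns are dropped along with the cache entries that will be evicted.
if max_cache_length is not None and cache_length + new_tokens > max_cache_length:
    attention_mask = attention_mask[:, -max_cache_length:]

# 236 usable cache entries + 20 new tokens = 256 mask columns, matching the window.
assert attention_mask.shape[1] == max_cache_length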
...
@@ -268,7 +268,7 @@ class MistralAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -363,7 +363,7 @@ class MistralFlashAttention2(MistralAttention):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
 
@@ -850,15 +850,13 @@ class MistralModel(MistralPreTrainedModel):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ class MistralForCausalLM(MistralPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...
@@ -295,7 +295,7 @@ class PersimmonAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -612,7 +612,7 @@ class PersimmonModel(PersimmonPreTrainedModel):
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if position_ids is None:
@@ -831,8 +831,10 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -846,10 +848,13 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...
@@ -334,7 +334,7 @@ class PhiAttention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -444,7 +444,7 @@ class PhiFlashAttention2(PhiAttention):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -855,15 +855,13 @@ class PhiModel(PhiPreTrainedModel):
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
        if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1085,8 +1083,10 @@ class PhiForCausalLM(PhiPreTrainedModel):
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1100,10 +1100,13 @@ class PhiForCausalLM(PhiPreTrainedModel):
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
...
@@ -187,3 +187,45 @@ class CacheIntegrationTest(unittest.TestCase):
         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces coherent English
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
+            "was visiting the historic district of Honolulu. Here,"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))