Unverified Commit 1f07200a authored by Younes Belkada, committed by GitHub

FIX: Add safe guards for static cache + llama on transformers latest (#401)

parent 5d7b0502
@@ -188,17 +188,20 @@ class QuantAttentionFused(nn.Module):
             # Always reset to 0
             self.start_pos = 0
 
+        hf_is_generating = False
+
+        if self.is_hf_transformers and "use_cache" in kwargs:
+            hf_is_generating = kwargs["use_cache"]
+
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (
-            self.is_hf_transformers
-            and "past_key_value" in kwargs
-            and kwargs["past_key_value"] is None
-        ):
+        if (self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
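Note: the new guard resets the rolling `start_pos` not only on the first step of `generate()` (detected via `past_key_value is None`), but also whenever the module is driven by `transformers` without `use_cache`, so a plain forward pass cannot leave a stale position behind. A minimal standalone sketch of the same decision logic, with `is_hf_transformers` and `kwargs` passed in as stand-ins for the module state (not code from this file):

def should_reset_start_pos(is_hf_transformers: bool, kwargs: dict) -> bool:
    # Mirrors the guard above: reset on the first step of `generate()`
    # (past_key_value is explicitly None) or whenever we are not generating
    # at all (use_cache absent or falsy).
    hf_is_generating = bool(kwargs.get("use_cache", False)) if is_hf_transformers else False
    first_generate_step = (
        is_hf_transformers
        and "past_key_value" in kwargs
        and kwargs["past_key_value"] is None
    )
    return first_generate_step or (is_hf_transformers and not hf_is_generating)

# A plain forward pass (no use_cache) resets the position.
assert should_reset_start_pos(True, {}) is True
# A decoding step inside generate() keeps the running position.
assert should_reset_start_pos(True, {"use_cache": True, "past_key_value": "<cache>"}) is False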
@@ -214,8 +217,6 @@ class QuantAttentionFused(nn.Module):
         if not self.use_alibi:
             xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
-        self.cache.to(xq)
-
         values_store = xv.transpose(2, 1)
         keys_store = (
             xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
@@ -223,6 +224,7 @@ class QuantAttentionFused(nn.Module):
             .contiguous()
         )
 
+        self.cache.to(xq)
         self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
 
         # Only necessary to retrieve from cache when we are not processing context
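Note: `self.cache` is a preallocated (static) KV cache, and `update_kv` writes the freshly computed keys/values into the slice starting at `self.start_pos`; moving `self.cache.to(xq)` next to the write keeps the buffers on the activations' device and dtype right before they are touched. A toy sketch of that kind of slot-write, using hypothetical names (`StaticKVCache`, `max_seq_len`) that are not taken from this file:

import torch

class StaticKVCache:
    """Toy stand-in for a preallocated KV cache written at a rolling offset."""

    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype=torch.float16):
        self.v = torch.zeros(bsz, n_heads, max_seq_len, head_dim, dtype=dtype)
        self.k = torch.zeros(bsz, n_heads, max_seq_len, head_dim, dtype=dtype)

    def to(self, ref: torch.Tensor):
        # Analogous to `self.cache.to(xq)`: match the activations' device/dtype.
        self.v = self.v.to(device=ref.device, dtype=ref.dtype)
        self.k = self.k.to(device=ref.device, dtype=ref.dtype)

    def update_kv(self, values, keys, start_pos, seqlen):
        # Write the new tokens into the slice [start_pos, start_pos + seqlen).
        self.v[:, :, start_pos : start_pos + seqlen] = values
        self.k[:, :, start_pos : start_pos + seqlen] = keys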
@@ -248,6 +250,11 @@ class QuantAttentionFused(nn.Module):
             # When seqlen is 1, there is nothing else to attend to
             if attention_mask is not None and seqlen > 1:
+                # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
+                # need to slice it
+                if attention_mask.shape[-1] != seqlen:
+                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]
+
                 scores = (
                     scores + attention_mask
                 )  # (bs, n_local_heads, slen, cache_len + slen)
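Note: recent `transformers` llama code can hand over a causal mask preallocated for the full static-cache window, so the fused kernel now trims it to the tokens actually being processed before adding it to the scores. A minimal shape-only sketch of the slicing, with illustrative sizes that are not taken from this file:

import torch

bsz, n_heads, seqlen, max_seq_len = 2, 8, 5, 32  # illustrative sizes only

scores = torch.randn(bsz, n_heads, seqlen, seqlen)
# A causal mask preallocated for the full window, as llama may provide it.
attention_mask = torch.zeros(bsz, 1, max_seq_len, max_seq_len)

# Same slicing as above: keep only the rows/columns for the current tokens.
if attention_mask.shape[-1] != seqlen:
    attention_mask = attention_mask[:, :, :seqlen, :seqlen]

scores = scores + attention_mask  # broadcasts over the head dimension
assert scores.shape == (bsz, n_heads, seqlen, seqlen)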
@@ -278,11 +285,15 @@ class QuantAttentionFused(nn.Module):
         attn_output = self.o_proj(attention_weight)
         self.start_pos += seqlen
 
+        if self.is_hf_transformers and not hf_is_generating:
+            self.start_pos = 0
+
         # past_key_value is replaced with cache_v, cache_k, returning empty data
         # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
         # about past key length
         past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
 
         if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
             new_cache = DynamicCache()
             new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
...
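Note: the dummy `past_key_value` tensor carries `self.start_pos` in its sequence-length dimension, which is all `transformers` needs to recover the past key length; when the new cache API is available, it is wrapped into a `DynamicCache`. A small sketch of that wrapping, assuming `transformers` >= 4.36 (where `DynamicCache` exists) and an illustrative `start_pos`:

import torch
from transformers import DynamicCache

start_pos = 7  # stand-in for self.start_pos after this forward pass
dummy = torch.zeros(1, 1, start_pos, 1)  # (bsz, n_heads, seq_len, head_dim) placeholder

cache = DynamicCache()
cache.update(dummy, dummy, layer_idx=0)

# transformers reads the sequence length from dim -2, which is why the dummy
# tensor carries start_pos there rather than real key/value data.
assert cache.get_seq_length() == start_pos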