Unverified Commit 5413b898 authored by Raushan Turganbay, committed by GitHub

KV cache is no longer a model attribute (#30730)

kv_cache is no longer a model attribute
parent 218f4413
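
Summary of the change (illustrative note, not part of the diff): each attention implementation touched below used to check whether a statically allocated cache had been attached to the module itself and, if so, prefer that instance attribute over the `past_key_value` argument. That lookup is removed; the cache is now only ever the object passed into `forward`. A minimal sketch of the before/after pattern, with a made-up `ToyAttention` module standing in for the real classes:

```python
# Minimal sketch of the pattern this commit removes; `ToyAttention` is a
# stand-in for the real attention classes, not actual transformers code.
import torch
from torch import nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states, past_key_value=None, cache_position=None):
        # Before this commit (the removed lines in the diff below), a static
        # cache could be attached to the module and silently shadow the
        # argument:
        #
        #     past_key_value = getattr(self, "past_key_value", past_key_value)
        #
        # After this commit only the argument is consulted, so the cache
        # object has to be passed explicitly on every forward call.
        if past_key_value is not None:
            # ... key/value states would be written into the cache here,
            # using cache_position to index the cache slots ...
            pass
        return hidden_states


# The cache now travels through the call, never through the module instance:
attn = ToyAttention()
_ = attn(torch.randn(1, 4, 8), past_key_value=None, cache_position=torch.arange(4))
```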
@@ -271,7 +271,6 @@ class CohereAttention(nn.Module):
         key_states = key_states.transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -365,8 +364,6 @@ class CohereFlashAttention2(CohereAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -571,9 +568,6 @@ class CohereSdpaAttention(CohereAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
...
@@ -287,7 +287,6 @@ class DbrxAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -387,8 +386,6 @@ class DbrxFlashAttention2(DbrxAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -600,8 +597,6 @@ class DbrxSdpaAttention(DbrxAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
...
@@ -262,7 +262,6 @@ class GemmaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
@@ -353,8 +352,6 @@ class GemmaFlashAttention2(GemmaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -552,8 +549,6 @@ class GemmaSdpaAttention(GemmaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
...
@@ -356,7 +356,6 @@ class LlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -452,8 +451,6 @@ class LlamaFlashAttention2(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -650,9 +647,6 @@ class LlamaSdpaAttention(LlamaAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
...
@@ -328,7 +328,6 @@ class OlmoAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
@@ -419,8 +418,6 @@ class OlmoFlashAttention2(OlmoAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
@@ -624,9 +621,6 @@ class OlmoSdpaAttention(OlmoAttention):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-        # In case static cache is used, it is an instance attribute.
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
...
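
For callers, the practical consequence is that a static cache is supplied through the public API rather than being read off the model. A hedged usage sketch follows; the model id is a placeholder and the exact `StaticCache` constructor arguments have varied across transformers versions:

```python
# Hedged usage sketch: supply the cache explicitly instead of relying on a
# model attribute. The model id is a placeholder, and StaticCache's exact
# constructor signature may differ between transformers versions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "your-org/your-decoder-model"  # placeholder: any model touched by this diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("The KV cache is", return_tensors="pt")

# Option 1: ask generate() to build and manage a static cache itself.
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")

# Option 2: build the cache yourself and pass it through `past_key_values`;
# after this commit it is never looked up on the model instance.
past_key_values = StaticCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=128,
    device=model.device,
    dtype=model.dtype,
)
with torch.no_grad():
    out = model(**inputs, past_key_values=past_key_values, use_cache=True)
```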