Unverified commit 609a1767, authored by Arthur, committed by GitHub

[`Cleanup`] Revert SDPA attention changes that got in the static kv cache PR (#29027)

* revert unrelated changes that got in

* style
parent 7a0fccc6
@@ -659,34 +659,28 @@ class MistralSdpaAttention(MistralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -695,9 +689,10 @@ class MistralSdpaAttention(MistralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
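
A note on the `attn_mask` / `is_causal` interplay this hunk restores: `torch.nn.functional.scaled_dot_product_attention` receives either an explicit 4D additive mask or `is_causal=True`, never both, and the flag is only set when `q_len > 1` because a single-token decode step has nothing left to mask. Below is a minimal sketch of that equivalence; the shapes and tensors are made up for illustration and are not part of the commit.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (bsz, num_heads, q_len, head_dim).
bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, n_heads, q_len, head_dim)
k = torch.randn(bsz, n_heads, q_len, head_dim)
v = torch.randn(bsz, n_heads, q_len, head_dim)

# Path taken when attention_mask is None and q_len > 1: SDPA builds the causal mask itself.
out_flag = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# Equivalent explicit additive mask, broadcastable to (bsz, 1, q_len, kv_seq_len):
# 0 where attention is allowed, -inf strictly above the diagonal.
causal_mask = torch.full((q_len, q_len), float("-inf")).triu(1)[None, None]
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask, is_causal=False)

assert torch.allclose(out_flag, out_mask, atol=1e-5)
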
@@ -736,34 +736,28 @@ class MixtralSdpaAttention(MixtralAttention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -772,9 +766,10 @@ class MixtralSdpaAttention(MixtralAttention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
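
The `repeat_kv` calls kept as context in these hunks are there because Mistral, Mixtral and Qwen2 support grouped-query attention: several query heads share one key/value head, and SDPA expects matching head counts on both sides. A rough sketch of that expansion follows, assuming the usual `(bsz, num_key_value_heads, seq_len, head_dim)` layout; the function body and the Mistral-7B-like numbers are illustrative, not copied from the repository.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (bsz, num_kv_heads, seq_len, head_dim) -> (bsz, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Mistral-7B-like configuration: 32 query heads sharing 8 KV heads -> n_rep = 4.
kv = torch.randn(1, 8, 16, 128)
assert repeat_kv(kv, 4).shape == (1, 32, 16, 128)
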
@@ -669,34 +669,28 @@ class Qwen2SdpaAttention(Qwen2Attention):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -705,9 +699,10 @@ class Qwen2SdpaAttention(Qwen2Attention):
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
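
Across all three files the revert goes back to trusting the `past_key_value` argument directly: `get_usable_length` reports how many tokens the cache already holds (added to `kv_seq_len`), and `update` appends the new key/value states and returns the full sequences, instead of reading a `past_key_value` attribute set on the module and passing explicit cache positions. The toy, single-layer stand-in below is not the transformers `Cache` API; it only illustrates the contract those two calls rely on.

import torch

class ToyLayerCache:
    def __init__(self):
        self.key_cache = None
        self.value_cache = None

    def get_usable_length(self, new_seq_len: int, layer_idx: int = 0) -> int:
        # Tokens already cached for this layer; the model adds this to kv_seq_len.
        return 0 if self.key_cache is None else self.key_cache.shape[-2]

    def update(self, key_states, value_states, layer_idx: int = 0, cache_kwargs=None):
        # cache_kwargs carries {"sin": ..., "cos": ...} in the RoPE models above; a plain
        # dynamic cache can ignore it, so this toy version does too.
        if self.key_cache is None:
            self.key_cache, self.value_cache = key_states, value_states
        else:
            self.key_cache = torch.cat([self.key_cache, key_states], dim=-2)
            self.value_cache = torch.cat([self.value_cache, value_states], dim=-2)
        return self.key_cache, self.value_cache

# Decode step: one new token appended after a 5-token prompt.
cache = ToyLayerCache()
cache.update(torch.randn(1, 8, 5, 128), torch.randn(1, 8, 5, 128))
k, v = cache.update(torch.randn(1, 8, 1, 128), torch.randn(1, 8, 1, 128))
assert k.shape[-2] == 6 and cache.get_usable_length(1) == 6
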