Unverified Commit 78172dcd authored by fxmarty, committed by GitHub

Fix SDPA correctness following torch==2.1.2 regression (#27973)

* fix sdpa with non-contiguous inputs for gpt_bigcode

* fix other archs

* add currently comment

* format
parent 5e4ef0a0
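The pattern applied throughout this commit is the same in every touched model: when running on CUDA with a custom attention mask, force the query/key/value tensors to be contiguous before calling SDPA, so that the memory-efficient path of torch==2.1.2 is never hit with non-contiguous inputs. A minimal sketch of that guard (the helper name is made up, not the exact transformers code):

```python
import torch.nn.functional as F


def sdpa_with_contiguous_workaround(query, key, value, attention_mask=None, is_causal=False):
    # Sketch of the guard used across the touched models (helper name is illustrative):
    # SDPA's memory-efficient backend in torch==2.1.2 can return wrong results when the
    # inputs are non-contiguous and a custom attn_mask is passed, see
    # https://github.com/pytorch/pytorch/issues/112577. Forcing contiguous inputs on CUDA
    # avoids that path's bug; on CPU or without a mask nothing changes.
    if query.device.type == "cuda" and attention_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )
```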
...@@ -583,6 +583,8 @@ class BartSdpaAttention(BartAttention):
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
...
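For BART-style attention no extra `.contiguous()` guard is needed because `_shape` already returns contiguous tensors. A hedged sketch of what such a `_shape` helper does (signature simplified; in the library it is a method that reads `num_heads`/`head_dim` from `self`):

```python
import torch


def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor:
    # Reshape [bsz, seq_len, num_heads * head_dim] -> [bsz, num_heads, seq_len, head_dim].
    # The trailing .contiguous() is why the BART/Whisper SDPA attention classes only need
    # a comment rather than an explicit workaround.
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
```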
...@@ -447,6 +447,13 @@ class FalconAttention(nn.Module):
else:
present = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
if alibi is None:
if self._use_sdpa and not output_attentions:
attn_output = F.scaled_dot_product_attention(
...
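To confirm that the memory-efficient backend is the one producing wrong results, SDPA can be pinned to a single backend and the outputs compared. A rough diagnostic sketch using the torch 2.1-era `torch.backends.cuda.sdp_kernel` context manager (newer PyTorch releases expose `torch.nn.attention.sdpa_kernel` instead); the function name is illustrative:

```python
import torch
import torch.nn.functional as F


def compare_sdpa_backends(query, key, value, attn_mask):
    # Run the same CUDA inputs through the math backend and the memory-efficient backend
    # separately and report the largest elementwise difference.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
        reference = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
        mem_efficient = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
    return (reference - mem_efficient).abs().max()
```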
...@@ -532,24 +532,37 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
if self.multi_query:
query_length = query_shape[1]
# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)
# Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
#
# so we could do:
#
# key = key.expand(-1, self.num_heads, -1, -1)
# value = value.expand(-1, self.num_heads, -1, -1)
#
# However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
# Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
# backend, but it feels very hacky.
else:
query_length = query_shape[-1]
# See the comment above.
if query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
...
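To make the multi-query shapes above concrete: keeping key/value at a head dimension of 1 and letting SDPA broadcast on the math path gives the same result the removed `expand` calls produced. An illustrative shape check with made-up sizes:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; the point is the broadcast over the head dimension.
batch_size, num_heads, query_length, past_length, head_dim = 2, 8, 4, 16, 32

# Multi-query layout as in the hunk above: many query heads, one shared key/value head.
query = torch.randn(batch_size, query_length, num_heads, head_dim).transpose(1, 2)  # [b, h, q, d], non-contiguous
key = torch.randn(batch_size, past_length, head_dim).unsqueeze(1)    # [b, 1, kv, d]
value = torch.randn(batch_size, past_length, head_dim).unsqueeze(1)  # [b, 1, kv, d]

# With key/value kept at a head dimension of 1, torch 2.1 falls back to the math backend,
# which broadcasts over that dimension; the expand calls were only needed to make the
# flash / memory-efficient kernels eligible.
out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 4, 32])
```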
...@@ -688,6 +688,13 @@ class IdeficsAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = nn.functional.scaled_dot_product_attention(
query_states,
key_states,
...
...@@ -506,7 +506,6 @@ class LlamaFlashAttention2(LlamaAttention):
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
...@@ -701,6 +700,7 @@ class LlamaSdpaAttention(LlamaAttention):
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
...@@ -716,6 +716,13 @@ class LlamaSdpaAttention(LlamaAttention):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
...
...@@ -764,6 +764,8 @@ class WhisperSdpaAttention(WhisperAttention):
query_states = self._shape(query_states, tgt_len, bsz)
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
# but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
...