Unverified Commit 78172dcd authored by fxmarty, committed by GitHub

Fix SDPA correctness following torch==2.1.2 regression (#27973)

* fix sdpa with non-contiguous inputs for gpt_bigcode

* fix other archs

* add comment about the current (torch==2.1.2) limitation

* format
parent 5e4ef0a0
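For context, below is a minimal sketch of the pattern this commit applies in each SDPA attention module. It is not part of the diff; the standalone helper name `sdpa_with_contiguous_workaround` is hypothetical, and the behaviour it guards against assumes torch==2.1.x on a CUDA device.

import torch
import torch.nn.functional as F


def sdpa_with_contiguous_workaround(query, key, value, attention_mask=None):
    # torch==2.1.2: the memory-efficient SDPA backend can return incorrect results
    # when query/key/value are non-contiguous and a custom attn_mask is passed
    # (https://github.com/pytorch/pytorch/issues/112577). Forcing contiguous inputs
    # on CUDA sidesteps the regression at the cost of a copy.
    if query.device.type == "cuda" and attention_mask is not None:
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)


# Example call with [batch, heads, seq_len, head_dim] tensors (CPU here, so the guard is a no-op).
q = k = v = torch.randn(1, 2, 4, 8)
out = sdpa_with_contiguous_workaround(q, k, v)
print(out.shape)  # torch.Size([1, 2, 4, 8])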
@@ -583,6 +583,8 @@ class BartSdpaAttention(BartAttention):
        query_states = self._shape(query_states, tgt_len, bsz)

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
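The Bart comment above (and the matching Whisper one at the end of the diff) leans on `_shape` already returning contiguous tensors. Roughly, and paraphrased as a standalone function rather than copied from the modeling code, the helper does something like this:

import torch


def _shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int) -> torch.Tensor:
    # [bsz, seq_len, num_heads * head_dim] -> [bsz, num_heads, seq_len, head_dim].
    # The transpose produces a non-contiguous view; the trailing .contiguous() is
    # what keeps the SDPA inputs clear of the torch==2.1.2 issue referenced above.
    return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()


hidden_states = torch.randn(2, 5, 4 * 16)
print(_shape(hidden_states, 5, 2, 4, 16).is_contiguous())  # True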
@@ -447,6 +447,13 @@ class FalconAttention(nn.Module):
        else:
            present = None

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask.
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        if alibi is None:
            if self._use_sdpa and not output_attentions:
                attn_output = F.scaled_dot_product_attention(
@@ -532,24 +532,37 @@ class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
        if self.multi_query:
            query_length = query_shape[1]

            # NOTE: Maybe there is a better way than this?
            # SDPA requires the dimension [..., sequence_length, head_dim].
            query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

            # Without these unsqueezes, SDPA complains as the query and key/value have different numbers of dimensions.
            key = key.unsqueeze(1)
            value = value.unsqueeze(1)
            # Although expanding key/value over the head dimension is not numerically necessary, PyTorch 2.1 cannot dispatch
            # to the memory-efficient or flash attention backends ("No available kernel. Aborting execution.") from the shapes
            # query = [batch_size, num_heads, query_length, head_dim]
            # key = [batch_size, 1, past_length, head_dim]
            # value = [batch_size, 1, past_length, head_dim]
            #
            # so we could do (these expands do not copy memory):
            #
            # key = key.expand(-1, self.num_heads, -1, -1)
            # value = value.expand(-1, self.num_heads, -1, -1)
            #
            # However, SDPA with the memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
            # so we always dispatch to the math path instead: https://github.com/pytorch/pytorch/issues/112577.
            # Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch to the memory-efficient
            # backend, but it feels very hacky.
        else:
            query_length = query_shape[-1]

            # See the comment above.
            if query.device.type == "cuda" and attention_mask is not None:
                query = query.contiguous()
                key = key.contiguous()
                value = value.contiguous()

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
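A short sketch of the multi-query shape reasoning in the GPTBigCode comment above; the sizes are arbitrary, and only the shapes and contiguity checks matter:

import torch

batch_size, num_heads, query_length, past_length, head_dim = 2, 8, 4, 16, 64
query = torch.randn(batch_size, query_length, num_heads * head_dim)
key = torch.randn(batch_size, past_length, head_dim)
value = torch.randn(batch_size, past_length, head_dim)

# SDPA wants [..., sequence_length, head_dim]; key/value get a singleton head dimension.
query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
key = key.unsqueeze(1)      # [batch_size, 1, past_length, head_dim]
value = value.unsqueeze(1)  # [batch_size, 1, past_length, head_dim]

# expand would broadcast the head dimension without copying memory, but the result is
# non-contiguous, which in torch==2.1.2 combines badly with a custom attn_mask on the
# memory-efficient backend (https://github.com/pytorch/pytorch/issues/112577).
expanded_key = key.expand(-1, num_heads, -1, -1)
print(expanded_key.shape)            # torch.Size([2, 8, 16, 64])
print(expanded_key.is_contiguous())  # False

# Keeping key/value as [batch_size, 1, past_length, head_dim] means the fused kernels
# refuse these shapes and SDPA falls back to the math path, which avoids the bug.
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 4, 64])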
@@ -688,6 +688,13 @@ class IdeficsAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -506,7 +506,6 @@ class LlamaFlashAttention2(LlamaAttention):
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
@@ -701,6 +700,7 @@ class LlamaSdpaAttention(LlamaAttention):
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
@@ -716,6 +716,13 @@ class LlamaSdpaAttention(LlamaAttention):
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@@ -764,6 +764,8 @@ class WhisperSdpaAttention(WhisperAttention):
        query_states = self._shape(query_states, tgt_len, bsz)

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` does call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,