Unverified commit 0ee5ccda authored by Xin Yao, committed by GitHub
Browse files

[PyTorch] Relax the contiguous check for flash attention (#1176)



* relax contiguous check for flash attention
Signed-off-by: Xin Yao <xiny@nvidia.com>

* force contiguous for cp
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent c0caadbe
...@@ -4881,20 +4881,19 @@ class FlashAttention(torch.nn.Module): ...@@ -4881,20 +4881,19 @@ class FlashAttention(torch.nn.Module):
) )
else: else:
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous() x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
for x in (query_layer, key_layer, value_layer)
] ]
elif qkv_format in ["bshd", "thd"]: if context_parallel:
query_layer, key_layer, value_layer = [ query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer) x.contiguous() for x in (query_layer, key_layer, value_layer)
] ]
else: else:
if qkv_format == "sbhd": if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [ query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1).contiguous() x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data) for x in (query_layer._data, key_layer._data, value_layer._data)
] ]
elif qkv_format in ["bshd", "thd"]: if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [ query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
] ]
...@@ -5092,11 +5091,7 @@ class FlashAttention(torch.nn.Module): ...@@ -5092,11 +5091,7 @@ class FlashAttention(torch.nn.Module):
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d() output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
else: else:
output = ( output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
elif qkv_format == "bshd": elif qkv_format == "bshd":
# (bs)hd -> bs(hd) # (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
...@@ -5104,7 +5099,7 @@ class FlashAttention(torch.nn.Module): ...@@ -5104,7 +5099,7 @@ class FlashAttention(torch.nn.Module):
# thd -> t(hd) # thd -> t(hd)
output = output.reshape(output.shape[0], -1) output = output.reshape(output.shape[0], -1)
return output return output.contiguous()
def _combine_tensors( def _combine_tensors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.