Unverified commit 0ee5ccda, authored by Xin Yao, committed by GitHub
Browse files

[PyTorch] Relax the contiguous check for flash attention (#1176)



* relax contiguous check for flash attention
Signed-off-by: Xin Yao <xiny@nvidia.com>

* force contiguous for cp
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent c0caadbe
......@@ -4881,20 +4881,19 @@ class FlashAttention(torch.nn.Module):
)
else:
query_layer, key_layer, value_layer = [
x.transpose(0, 1).contiguous()
for x in (query_layer, key_layer, value_layer)
x.transpose(0, 1) for x in (query_layer, key_layer, value_layer)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer, key_layer, value_layer = [
x.contiguous() for x in (query_layer, key_layer, value_layer)
]
else:
if qkv_format == "sbhd":
query_layer._data, key_layer._data, value_layer._data = [
x.transpose(0, 1).contiguous()
x.transpose(0, 1)
for x in (query_layer._data, key_layer._data, value_layer._data)
]
elif qkv_format in ["bshd", "thd"]:
if context_parallel:
query_layer._data, key_layer._data, value_layer._data = [
x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data)
]
......@@ -5092,11 +5091,7 @@ class FlashAttention(torch.nn.Module):
output.reshape(batch_size * max_seqlen_q // cp_size, -1).transpose_2d()
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
else:
output = (
output.view(batch_size, max_seqlen_q // cp_size, -1)
.transpose(0, 1)
.contiguous()
)
output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1)
elif qkv_format == "bshd":
# (bs)hd -> bs(hd)
output = output.reshape(batch_size, max_seqlen_q // cp_size, -1)
......@@ -5104,7 +5099,7 @@ class FlashAttention(torch.nn.Module):
# thd -> t(hd)
output = output.reshape(output.shape[0], -1)
return output
return output.contiguous()
def _combine_tensors(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment