Unverified Commit 2a81e939 authored by Kirthi Shankar Sivamani, committed by GitHub

FA does not support head_dim > 64 on Ada (#328)



* FA does not support head_dim > 64 on Ada
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add cc8.7 to no FA list
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a3e4e611
@@ -879,7 +879,7 @@ class DotProductAttention(torch.nn.Module):
         if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
             or value_layer.dtype not in [torch.bfloat16, torch.float16]
-            or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64)
+            or (self.device_compute_capability in (8.6, 8.7, 8.9) and key_layer.shape[-1] > 64)
         ):
             use_flash_attention = False
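The gating logic in this patch can be sketched as a standalone helper. This is a minimal illustration, not the upstream implementation: the function name is hypothetical, and it only reproduces the capability/head-dim condition from the diff (sm86, sm87, and sm89 limit flash attention to head_dim <= 64, per the commit message).

```python
def can_use_flash_attention(compute_capability: float, head_dim: int) -> bool:
    """Return False when flash attention must be disabled.

    Mirrors the check in the diff above: on compute capabilities
    8.6 (Ampere GA10x), 8.7 (Orin), and 8.9 (Ada), flash attention
    does not support head dimensions larger than 64, so the caller
    should fall back to the unfused attention path.
    """
    # Hypothetical helper; capability values and the 64 limit come
    # straight from the patch.
    if compute_capability in (8.6, 8.7, 8.9) and head_dim > 64:
        return False
    return True


# Example: Ada (8.9) with head_dim 128 falls back; Hopper-class
# capabilities are unaffected by this particular check.
print(can_use_flash_attention(8.9, 128))  # False
print(can_use_flash_attention(9.0, 128))  # True
```

In the real module the capability comes from the device at init time (e.g. via `torch.cuda.get_device_capability`), which is why the patch compares against a stored `self.device_compute_capability` rather than querying the GPU on every forward pass.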