"qa/vscode:/vscode.git/clone" did not exist on "43569381ce3779f9bf6084da917556db210d745d"
Unverified commit 5a881a08 authored by Kirthi Shankar Sivamani, committed by GitHub

Catch FA internal error with compute capability 8.6 (#113)

FlashAttention (FA) does not support head_dim > 64 on compute capability 8.6, so disable it in that case instead of hitting an internal error.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5f0d3868
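For reference, here is a minimal sketch of what the `get_device_compute_capability` helper used in the diff below could look like. The `torch.cuda`-based body is an assumption for illustration only, not necessarily TransformerEngine's actual implementation.

```python
import torch

def get_device_compute_capability() -> float:
    # Assumed implementation (sketch only): report the current CUDA
    # device's compute capability as a float, e.g. (8, 6) -> 8.6.
    major, minor = torch.cuda.get_device_capability()
    return major + minor / 10
```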
@@ -353,10 +353,11 @@ class DotProductAttention(torch.nn.Module):
         norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        self.device_compute_capability = get_device_compute_capability()
         self.use_flash_attention = (
             int(os.getenv("NVTE_FLASH_ATTN", "1"))
             and attn_mask_type == "causal"
-            and get_device_compute_capability() >= 8.0
+            and self.device_compute_capability >= 8.0
         )
         attn_kwargs = {
@@ -437,6 +438,7 @@ class DotProductAttention(torch.nn.Module):
         if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
             or value_layer.dtype not in [torch.bfloat16, torch.float16]
+            or (self.device_compute_capability == 8.6 and key_layer.shape[-1] > 64)
         ):
             use_flash_attention = False
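Taken together, the two hunks amount to the guard restated below. This is a hedged, illustrative restatement, not code from the commit: `should_use_flash_attention` is a hypothetical helper, and it reuses the `get_device_compute_capability` sketch from above.

```python
import os
import torch

def should_use_flash_attention(attn_mask_type: str, key_layer: torch.Tensor) -> bool:
    # Hypothetical helper restating the combined guard from this commit.
    cc = get_device_compute_capability()  # e.g. 8.6 on an sm_86 GPU
    head_dim = key_layer.shape[-1]        # FA's head-dim limit applies here
    if not int(os.getenv("NVTE_FLASH_ATTN", "1")):
        return False  # flash attention disabled via environment variable
    if attn_mask_type != "causal" or cc < 8.0:
        return False  # FA path requires a causal mask and Ampere or newer
    if cc == 8.6 and head_dim > 64:
        return False  # FA does not support head_dim > 64 on compute 8.6
    return True
```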