Unverified Commit f0ddab82 authored by Kirthi Shankar Sivamani, committed by GitHub

Relax FA 2.0 checks for Ada (#331)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 10eb13e2
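
Context for the change: "FA 2.0" is FlashAttention 2.x, and "Ada" is NVIDIA's Ada Lovelace architecture (compute capability 8.9). FlashAttention 2.0 added support for head dimensions above 64 on Ada, so the blanket restriction no longer needs to cover 8.9 when FA 2.x is installed. As a minimal sketch only, the module-level _flash_attn_2_available flag used in the diff could be derived from the installed flash-attn package version; the probe below is an assumption for illustration, not code from this commit:

# Hedged sketch: compute a FlashAttention-2 availability flag from the
# installed flash-attn package version. The repo's actual mechanism may
# differ; this is illustrative only.
from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version

try:
    _flash_attn_version = Version(version("flash-attn"))
except PackageNotFoundError:
    # flash-attn not installed at all.
    _flash_attn_version = Version("0")

_flash_attn_2_available = _flash_attn_version >= Version("2.0.0")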
@@ -885,10 +885,15 @@ class DotProductAttention(torch.nn.Module):
         if (query_layer.dtype not in [torch.bfloat16, torch.float16]
             or key_layer.dtype not in [torch.bfloat16, torch.float16]
             or value_layer.dtype not in [torch.bfloat16, torch.float16]
-            or (self.device_compute_capability in (8.6, 8.7, 8.9) and key_layer.shape[-1] > 64)
         ):
             use_flash_attention = False
 
+        if key_layer.shape[-1] > 64:
+            if self.device_compute_capability in (8.6, 8.7):
+                use_flash_attention = False
+            elif not _flash_attn_2_available and self.device_compute_capability == 8.9:
+                use_flash_attention = False
+
         if self.attn_mask_type == "padding" and attention_mask is not None:
             use_flash_attention = False
             use_fused_attention = False
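
The net effect of the relaxed check, restated as a standalone predicate for clarity. The function below is a hypothetical illustration, not repo code; in the diff the logic runs inline using self.device_compute_capability and the module-level _flash_attn_2_available flag:

def _flash_attention_allowed(head_dim: int, compute_capability: float,
                             flash_attn_2_available: bool) -> bool:
    """Illustrative restatement of the check above (not repo code)."""
    if head_dim > 64:
        # sm86/sm87 (Ampere variants): head_dim > 64 is never supported here.
        if compute_capability in (8.6, 8.7):
            return False
        # sm89 (Ada): head_dim > 64 now requires only FlashAttention 2.x.
        if compute_capability == 8.9 and not flash_attn_2_available:
            return False
    return True

# The behavioral change: Ada with FlashAttention 2 installed keeps flash
# attention enabled for head_dim > 64, where it was previously disabled.
assert _flash_attention_allowed(128, 8.9, flash_attn_2_available=True)
assert not _flash_attention_allowed(128, 8.9, flash_attn_2_available=False)
assert not _flash_attention_allowed(128, 8.6, flash_attn_2_available=True)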