"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "ffa24475a97f9223659effbcf4ccda6d1adb9a18"
Unverified commit f70b4bbf authored by Ming-Xu Huang, committed by GitHub

[JAX] Enhance fall-back conditions for fMHA. (#260)


Signed-off-by: Ming Huang <mingh@nvidia.com>
parent 144e4888
@@ -374,7 +374,7 @@ class MultiHeadAttention(nn.Module):
         use_fused_attn = not decode and not self.transpose_batch_sequence and self.fuse_qkv and \
             self.dropout_rate == 0 and canonicalize_dtype in [jnp.bfloat16, jnp.float16] and \
             q_seqlen in fused_attn_supported_seqlen and kv_seqlen in fused_attn_supported_seqlen \
-            and is_fused_attn_kernel_available() and enable_fused_attn
+            and is_fused_attn_kernel_available() and (self.head_dim == 64) and enable_fused_attn
         if enable_fused_attn and not use_fused_attn:
             reason = ""
@@ -399,6 +399,8 @@ class MultiHeadAttention(nn.Module):
                          f"but got {kv_seqlen=}, "
             if not is_fused_attn_kernel_available():
                 reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
+            if self.head_dim != 64:
+                reason += f"head_dim should be 64 but got {self.head_dim}, "
             warnings.warn(
                 f"Fused attention is not enabled, " \
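The diff follows a simple gating pattern: the fused attention path is used only when every support condition holds (no decode, no dropout, fp16/bf16 dtype, supported sequence lengths, a compatible kernel, and now head_dim == 64), and when the user requested fused attention but a condition fails, a warning lists each failed check before the module falls back to the unfused path. The standalone sketch below illustrates that pattern only; it is not the TransformerEngine implementation, and the check_fused_attn helper and its parameter list are illustrative assumptions.

import warnings

def check_fused_attn(enable_fused_attn: bool,
                     head_dim: int,
                     dropout_rate: float,
                     kernel_available: bool) -> bool:
    # Sketch of the gating pattern from the diff: the fused path is enabled
    # only when every support condition holds.
    use_fused_attn = (enable_fused_attn
                      and dropout_rate == 0
                      and kernel_available
                      and head_dim == 64)  # condition added by this commit

    # If fused attention was requested but cannot be used, explain why and
    # let the caller fall back to the unfused attention path.
    if enable_fused_attn and not use_fused_attn:
        reason = ""
        if dropout_rate != 0:
            reason += f"dropout_rate should be 0 but got {dropout_rate}, "
        if not kernel_available:
            reason += "GPU arch >= Ampere and cuDNN >= 8.9.1 are required, "
        if head_dim != 64:
            reason += f"head_dim should be 64 but got {head_dim}, "
        warnings.warn(f"Fused attention is not enabled, {reason}falling back to unfused attention.")

    return use_fused_attn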