Unverified Commit 71b2dd48 authored by Kshitij Lakhani, committed by GitHub
Browse files

Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991)



Fix cudnn versioning support in PyTorch DPA and Fused attn
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
parent ee841084
......@@ -251,10 +251,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
(head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
cudnn_runtime_version >= 91100)) &&
// 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 &&
head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
head_dim_qk != head_dim_v))) &&
// 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
(!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training &&
sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
!(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 &&
......
......@@ -434,8 +434,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12")
use_fused_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment