Unverified commit 7f77127c authored by Kshitij Lakhani, committed by GitHub

Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185)



* Fix cuDNN version checks for the KV cache on sm89. Add a cuDNN version check when getting the backend, in preparation for 9.14
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fix for cuDNN version condition check
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c334fc46
@@ -251,11 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
           // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
           (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
            cudnn_runtime_version >= 91100)) &&
-          // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
-          (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 ||
-              cudnn_runtime_version == 91300) &&
-             is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 &&
-             !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) &&
+          // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
+          // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
+          (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&
+             head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) &&
+             head_dim_qk != head_dim_v))) &&
           // bias type
           ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
            (cudnn_runtime_version >= 8906 &&
@@ -434,8 +434,10 @@ def get_attention_backend(
     # |             | FP8            | non-paged/paged | sm90 | thd           | >= 1
     # Unfused      | FP32/FP16/BF16 | non-paged/paged | all  | bshd,sbhd,thd | >= 1
     if inference_params is not None:
-        if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0):
-            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13")
+        # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version
+        # until the cuDNN bug is resolved
+        if device_compute_capability == (8, 9):
+            logger.debug("Disabling FusedAttention for KV caching for sm89")
             use_fused_attention = False
         if context_parallel:
             logger.debug("Disabling all backends for KV caching with context parallelism")
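The Python hunk narrows a version-conditional disable into an unconditional one for sm89. A minimal sketch of the gate after this commit, with a hypothetical helper name (the real code sets a flag inside `get_attention_backend` rather than returning from a function):

```python
# Sketch of the KV-cache gate after this commit: fused attention is disabled
# for KV caching on sm89 regardless of the cuDNN version. The helper name is
# hypothetical; tuples mirror the (major, minor[, patch]) comparisons in the diff.
def allow_fused_attention_for_kv_cache(device_compute_capability: tuple,
                                       cudnn_version: tuple) -> bool:
    # Old gate (pre-commit) disabled only up to a known-bad version:
    #   device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0)
    if device_compute_capability == (8, 9):
        # Blanket disable until the cuDNN bug is resolved.
        return False
    return True
```

Comparing versions as tuples works because Python compares tuples element-wise, so `(9, 14, 0) <= (9, 13, 0)` is correctly `False`; the commit drops that check entirely rather than extending it to 9.14.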