Unverified Commit cc0cb35d authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895)



* skip kv cache for sm89, cudnn < 9.12
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix test_numerics
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 5b16807c
......@@ -2322,9 +2322,9 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
if (
backend == "FusedAttention"
and get_device_compute_capability() == (8, 9)
and get_cudnn_version() < (9, 11, 0)
and get_cudnn_version() < (9, 12, 0)
):
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12")
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
......
......@@ -433,8 +433,8 @@ def get_attention_backend(
# | FP8 | non-paged/paged | sm90 | thd | >= 1
# Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1
if inference_params is not None:
if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11")
if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0):
logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12")
use_fused_attention = False
if context_parallel:
logger.debug("Disabling all backends for KV caching with context parallelism")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment