Unverified commit c203f527 authored by Charlene Yang, committed by GitHub

[PyTorch] Disable KV cache for sm89 and cuDNN < 9.11 (#1776)



* disable sm89 and cuDNN < 9.11 for KV caching
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* disable some numerics tests
Signed-off-by: Charlene Yang <charleney@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <charleney@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 7e43feae
@@ -42,7 +42,7 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
@@ -2293,6 +2293,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
         pytest.skip("KV cache does not support starting positions for RoPE")
+    if (
+        backend == "FusedAttention"
+        and get_device_compute_capability() == (8, 9)
+        and get_cudnn_version() < (9, 11, 0)
+    ):
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")
     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
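A side note on the guard above: both helpers return version-like tuples, and Python compares tuples lexicographically, which is what makes the `< (9, 11, 0)` check behave like a dotted-version comparison. A minimal sketch of that semantics (plain Python, not part of this commit):

```python
# Tuples compare element by element, left to right, so version-like tuples
# order the same way the corresponding dotted version strings would.
assert (9, 10, 2) < (9, 11, 0)            # cuDNN 9.10.2 is below the 9.11 cutoff
assert not ((9, 11, 1) < (9, 11, 0))      # cuDNN 9.11.1 passes the check
assert (8, 9) == (8, 9)                   # sm89 reported as a (major, minor) pair
```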
@@ -433,6 +433,9 @@ def get_attention_backend(
     #          | FP8            | non-paged/paged | sm90 | thd           | >= 1
     # Unfused  | FP32/FP16/BF16 | non-paged/paged | all  | bshd,sbhd,thd | >= 1
     if inference_params is not None:
+        if device_compute_capability == (8, 9) and cudnn_version < (9, 11, 0):
+            logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.11")
+            use_fused_attention = False
         if context_parallel:
             logger.debug("Disabling all backends for KV caching with context parallelism")
             use_flash_attention = False
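For readers who want to check whether their own environment is affected, here is a minimal sketch that mirrors the new gate using the same helpers the test diff imports; the function name `fused_attention_kv_cache_supported` is hypothetical and not part of Transformer Engine:

```python
# Minimal sketch, not part of this commit: reproduce the sm89 / cuDNN < 9.11
# gate with the helpers imported in the test diff above.
from transformer_engine.pytorch.utils import (
    get_cudnn_version,
    get_device_compute_capability,
)


def fused_attention_kv_cache_supported() -> bool:
    """Return False on sm89 with cuDNN < 9.11, mirroring the gate in this commit."""
    if get_device_compute_capability() == (8, 9) and get_cudnn_version() < (9, 11, 0):
        return False
    return True


if __name__ == "__main__":
    print("FusedAttention KV caching available:", fused_attention_kv_cache_supported())
```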