Unverified Commit 490a5f41 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Fix attention backend and tests for `sm120` (#2320)



* Fix attention backend and tests for sm120
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Disable MLA only for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 5e8a9a96
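For context, a minimal sketch of the capability gate the test-side changes introduce: `fp8_attn_available` is derived from the generic `fp8_available` flag but additionally requires a compute capability of at least sm90 and below sm120. The helper below is illustrative only and uses `torch.cuda.get_device_capability()` instead of the repo's `get_device_compute_capability()` utility, purely so it runs standalone.

```python
# Hypothetical sketch of the FP8-attention gate added in this commit; the real
# tests use the repo's get_device_compute_capability() helper instead.
import torch


def fp8_attention_supported():
    """FP8 attention is gated to sm90 <= compute capability < sm120 per this commit."""
    if not torch.cuda.is_available():
        return False, "CUDA device is required"
    cc = torch.cuda.get_device_capability()  # e.g. (12, 0) on sm120
    if cc < (9, 0) or cc >= (12, 0):
        return False, f"FP8 attention is not supported for compute capability = sm{cc[0] * 10 + cc[1]}"
    return True, ""


if __name__ == "__main__":
    ok, reason = fp8_attention_supported()
    print("fp8_attn_available:", ok, reason)
```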
@@ -61,8 +61,16 @@ from utils import (
get_available_attention_backends,
)
# Check if hardware supports FP8
# Check if hardware supports FP8 attention.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8
device_compute_capability = get_device_compute_capability()
if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)):
    fp8_attn_available = False
    reason_for_no_fp8_attn = (
        "FP8 attention is not supported for compute capability ="
        f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
    )
# Reset RNG seed and states
seed = 1234
@@ -1573,8 +1581,7 @@ model_configs_fp8_extra_state = {
}
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.")
@pytest.mark.parametrize("model", ["large"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@@ -1736,8 +1743,7 @@ qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16)
@@ -1973,8 +1979,7 @@ def _run_mha_fp8_vs_f16(
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16)
@@ -2302,8 +2307,7 @@ models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
),
reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
)
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
@pytest.mark.parametrize("dtype", param_types_fp8)
@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
def test_custom_mha_fp8_vs_f16(dtype, model):
@@ -481,6 +481,20 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0")
use_fused_attention = False
if device_compute_capability == (12, 0):
    if use_flash_attention:
        logger.debug(
            "Disabling FlashAttention as FP8 is not supported"
            " for compute capability = sm120"
        )
    if use_fused_attention:
        logger.debug(
            "Disabling FusedAttention as FP8 is not supported"
            " for compute capability = sm120"
        )
    use_flash_attention = False
    use_fused_attention = False
# Filter: Return max_logit
if return_max_logit:
if use_flash_attention:
@@ -560,6 +574,20 @@ def get_attention_backend(
qkv_layout,
)
use_fused_attention = False
if (
    device_compute_capability == (12, 0)
    and (head_dim_qk > 128 or head_dim_qk % 8 != 0)
    and is_training
):
    if use_fused_attention:
        logger.debug(
            "Disabling FusedAttention as MLA for backward pass is not supported for compute"
            " capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:"
            " head_dim_qk = %s",
            head_dim_qk,
        )
        use_fused_attention = False
if use_flash_attention_2 and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
@@ -629,6 +657,13 @@ def get_attention_backend(
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
if device_compute_capability == (12, 0):
    if use_fused_attention:
        logger.debug(
            "Disabling FusedAttention as qkv_format = thd is"
            " not supported for compute capability = sm120"
        )
        use_fused_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention_3:
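Taken together, the backend changes add three sm120-specific filters to `get_attention_backend`: FP8 attention is disabled for both FlashAttention and FusedAttention, MLA-style head dimensions (head_dim_qk > 128 or not a multiple of 8) disable FusedAttention only when a backward pass is needed, and the `thd` qkv_format disables FusedAttention. A condensed, hypothetical sketch of that logic follows; variable names mirror the diff, but the function itself is not the library API.

```python
# Illustrative, condensed view of the sm120 filters added in this commit.
# The function and its signature are assumptions made for readability; the real
# logic lives inside get_attention_backend() and touches many more flags.

def apply_sm120_filters(
    device_compute_capability,  # (major, minor), e.g. (12, 0)
    fp8,                        # FP8 attention requested
    is_training,                # backward pass needed
    head_dim_qk,
    qkv_format,                 # "bshd", "sbhd", or "thd"
    use_flash_attention,
    use_fused_attention,
):
    if device_compute_capability == (12, 0):
        # 1) FP8 attention is not supported on sm120 for either backend.
        if fp8:
            use_flash_attention = False
            use_fused_attention = False
        # 2) MLA-style head dims (> 128 or not a multiple of 8) only disable
        #    FusedAttention when training, i.e. when a backward pass is required.
        if is_training and (head_dim_qk > 128 or head_dim_qk % 8 != 0):
            use_fused_attention = False
        # 3) Variable-length thd layouts are not supported by FusedAttention.
        if qkv_format == "thd":
            use_fused_attention = False
    return use_flash_attention, use_fused_attention


# Example: an MLA head dim keeps FusedAttention for inference but not training.
print(apply_sm120_filters((12, 0), False, True, 192, "bshd", True, True))   # (True, False)
print(apply_sm120_filters((12, 0), False, False, 192, "bshd", True, True))  # (True, True)
```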