Unverified Commit 0f0e229b authored by Chen Cui, committed by GitHub

[PyT] Update THD sink attention logic for cudnn >=9.18.0 (#2568)



* Update THD sink attention logic for newer cudnn versions

THD sink attention is supported in cuDNN 9.18.0.
Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update thd sink attention logic for cp>1
Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add unit test for thd + sink attention
Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address comments
Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* do not skip thd cp sink attention test
Signed-off-by: Chen Cui <chcui@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* disable deterministic mode for sink attention
Signed-off-by: Chen Cui <chcui@nvidia.com>

---------
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 3d46bf61
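The commit gates THD sink attention on the installed cuDNN version (>= 9.18.0). Below is a minimal user-side sketch of the same check; it assumes `get_cudnn_version()` is importable from `transformer_engine.pytorch.utils` (the import path is an assumption; the tests in this diff call the same helper) and returns a `(major, minor, patch)` tuple.

```python
# Hypothetical user-side guard mirroring the gate added in this commit.
# Assumption: get_cudnn_version() lives in transformer_engine.pytorch.utils
# and returns a (major, minor, patch) tuple, as the tests below suggest.
from transformer_engine.pytorch.utils import get_cudnn_version


def thd_sink_attention_supported() -> bool:
    """THD (packed) layouts with non-vanilla softmax (off-by-one / learnable
    sink attention) require cuDNN >= 9.18.0."""
    return get_cudnn_version() >= (9, 18, 0)


if not thd_sink_attention_supported():
    # On older cuDNN, fall back to vanilla softmax or a padded BSHD/SBHD layout.
    print("cuDNN < 9.18.0: THD sink attention unavailable, falling back.")
```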
@@ -429,6 +429,15 @@ def test_dpa_softmax(dtype, model_configs, model):
     )
 
 
+@pytest.mark.skipif(get_cudnn_version() < (9, 18, 0), reason="cuDNN 9.18.0+ is required.")
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("model_configs", [model_configs_softmax])
+@pytest.mark.parametrize("model", model_configs_softmax.keys())
+def test_dpa_softmax_thd(dtype, model_configs, model):
+    """Test DotProductAttention module with different softmax types"""
+    test_dot_product_attention(dtype, model_configs, model, True, True, "thd_thd_thd", False, False)
+
+
 model_configs_mla = {
     # test: ModelConfig(b, sq, hq, dqk)
     "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),
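For context on the `"thd_thd_thd"` layout the new test exercises: THD packs variable-length sequences along the token dimension and describes them with cumulative sequence lengths. The standalone sketch below uses plain PyTorch only (no Transformer Engine APIs; all names are illustrative).

```python
import itertools

import torch

# Three sequences of different lengths packed into one
# [total_tokens, num_heads, head_dim] tensor (the "t" in thd).
seq_lens = [5, 3, 7]
num_heads, head_dim = 16, 64
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.bfloat16)

# Cumulative sequence lengths mark the sequence boundaries; fused attention
# kernels consume these instead of padding masks.
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seq_lens)), dtype=torch.int32)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```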
@@ -283,9 +283,14 @@ def test_cp_with_fused_attention(
         pytest.skip(
             "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!"
         )
-    if config.softmax_type != "vanilla" and qkv_format == "thd":
+    if (
+        get_cudnn_version() < (9, 18, 0)
+        and config.softmax_type != "vanilla"
+        and qkv_format == "thd"
+    ):
         pytest.skip(
-            "CP implementation does not support qkv_format=thd for non-vanilla softmax types!"
+            "Unless cudnn version >= 9.18.0, CP implementation does not support qkv_format=thd for"
+            " non-vanilla softmax types!"
         )
     dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16}
@@ -4026,28 +4026,30 @@ def attn_forward_func_with_cp(
     assert not sliding_window_attn or cp_comm_type in [
         "a2a",
         "all_gather",
-    ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!"
+    ], f"Context parallelism does not support sliding window attention with {cp_comm_type=}!"
     enable_mla = k.shape[-1] != v.shape[-1]
     assert not enable_mla or cp_comm_type in [
         "p2p",
         "a2a+p2p",
-    ], "Context parallelism does not support MLA with {cp_comm_type=}!"
+    ], f"Context parallelism does not support MLA with {cp_comm_type=}!"
     if fp8 and fp8_meta is not None:
         if fp8_meta["recipe"].fp8_dpa:
             assert (
                 softmax_type == "vanilla"
-            ), "Context parallelism does not support {softmax_type=} with FP8 attention!"
+            ), f"Context parallelism does not support {softmax_type=} with FP8 attention!"
     assert (
         softmax_type == "vanilla" or use_fused_attention
-    ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!"
+    ), f"Context parallelism only supports {softmax_type=} with FusedAttention backend!"
     assert (
         softmax_type == "vanilla" or cp_comm_type == "a2a"
-    ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
-    assert (
-        softmax_type == "vanilla" or qkv_format != "thd"
-    ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!"
+    ), f"Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!"
+    if get_cudnn_version() < (9, 18, 0):
+        assert softmax_type == "vanilla" or qkv_format != "thd", (
+            f"Before cuDNN 9.18.0, context parallelism does not support {softmax_type=} with"
+            " qkv_format = 'thd'!"
+        )
     args = [
         is_training,
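Besides the version gate, this hunk adds the missing `f` prefixes on the assertion messages, so the `{cp_comm_type=}` / `{softmax_type=}` placeholders are actually interpolated instead of being printed literally. A quick illustration of the difference (standard Python behavior, not TE-specific):

```python
cp_comm_type = "p2p"

# Plain string: the placeholder is printed literally.
print("Context parallelism does not support MLA with {cp_comm_type=}!")
# -> Context parallelism does not support MLA with {cp_comm_type=}!

# f-string: the self-documenting expression expands to the name and value.
print(f"Context parallelism does not support MLA with {cp_comm_type=}!")
# -> Context parallelism does not support MLA with cp_comm_type='p2p'!
```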
@@ -716,22 +716,14 @@ def get_attention_backend(
             )
             use_unfused_attention = False
         if qkv_format == "thd":
-            logger.debug(
-                "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type
-            )
-            use_fused_attention = False
-            logger.debug(
-                "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd",
-                softmax_type,
-            )
-            use_unfused_attention = False
+            if cudnn_version < (9, 18, 0):
+                logger.debug(
+                    "Disabling FusedAttention for softmax_type = %s, qkv_format = thd and cuDNN"
+                    " version < 9.18",
+                    softmax_type,
+                )
+                use_fused_attention = False
         if context_parallel:
-            logger.debug(
-                "Disabling UnfusedDotProductAttention for context parallelism with softmax_type"
-                " = %s",
-                softmax_type,
-            )
-            use_unfused_attention = False
             if cp_comm_type != "a2a":
                 logger.debug(
                     "Disabling FusedAttention for context parallelism with softmax_type = %s and"
@@ -1049,6 +1041,15 @@ def get_attention_backend(
             )
             use_flash_attention_2 = False
     if use_fused_attention and deterministic:
+        if softmax_type != "vanilla":
+            logger.debug(
+                "Disabling FusedAttention for determinism reasons with softmax_type = %s. "
+                "Sink attention (off-by-one and learnable softmax) requires "
+                "NVTE_ALLOW_NONDETERMINISTIC_ALGO=1",
+                softmax_type,
+            )
+            use_fused_attention = False
+            fused_attention_backend = None
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
             logger.debug("Disabling FusedAttention for determinism reasons with FP8")
             use_fused_attention = False
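The last hunk makes the deterministic path refuse FusedAttention for non-vanilla (sink) softmax types and points users at `NVTE_ALLOW_NONDETERMINISTIC_ALGO=1`. A hedged sketch of the user-facing choice; the environment variable name is taken from the new debug message, everything else is illustrative.

```python
import os
import torch

# Option 1: allow nondeterministic kernels so FusedAttention can run sink
# attention (off-by-one / learnable softmax), as the new debug message suggests.
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"

# Option 2: keep determinism and accept that FusedAttention is disabled for
# non-vanilla softmax types, falling back to another backend if one is available.
# torch.use_deterministic_algorithms(True)
```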