Unverified commit 27fc168e authored by Charlene Yang, committed by GitHub

[Common] Enable determinism for cuDNN >= 9.18.1 on Blackwell (#2584)
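This change plumbs a `deterministic` flag through the common C API so that, on Blackwell (sm100+), the cuDNN fused-attention backend can be selected deterministically with cuDNN 9.18.1+. Determinism is requested through the `NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable, as exercised by the QA scripts below. A minimal usage sketch (illustrative, not part of this diff):

```python
# Illustrative sketch: request deterministic fused attention in Transformer Engine.
# NVTE_ALLOW_NONDETERMINISTIC_ALGO defaults to "1" (non-deterministic kernels allowed);
# it must be set to "0" before attention backend selection runs.
import os

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
```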



* update FE to 1.17
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism flag
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to qa/
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* move bias/dbias/versioning/dropout logic to C API
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update qa/L0_pytorch_unittest/test.sh

make .xml file specific to deterministic tests in qa/
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to Jax extension
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add determinism to Jax tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update tests/jax/test_fused_attn.py

fix typo
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/common/fused_attn/fused_attn.cpp

fix indentation
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the AI fixes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix Jax extension call
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fixes based on comments
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix selection logic and fwd arg
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix version check in Jax test
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix pytorch CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix Jax CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix non-/determinism logic and CI
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix formatting
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/fused_attn/fused_attn.cpp

fix and/or logic
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* update to 9.18.1 for requirement
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* reduce Jax CI tests for determinism
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent dfdd3820
-Subproject commit 0258951d4d512f4714eb1574496f4d57669b1b93
+Subproject commit b372d39879d44c91a8d5b342022e74802b6a8da2
......@@ -29,6 +29,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_fused_attn_with_determinism.xml $TE_PATH/tests/jax/test_fused_attn.py -k "TestFusedAttnWithDeterminism" || test_fail "tests/jax/test_fused_attn.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
......
......@@ -45,6 +45,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_e
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
......
......@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
import os
from enum import Enum, auto
from dataclasses import dataclass, field
from functools import partial
......@@ -49,6 +50,9 @@ from transformer_engine_jax import (
from distributed_test_base import assert_equal_collectives
from utils import assert_allclose, print_debug_tensor_stats
# Get determinism
_deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@pytest.fixture(autouse=True, scope="module")
def init():
......@@ -413,14 +417,24 @@ class FusedAttnRunner:
pytest.skip(
"seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN"
)
-        # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support
-        if (
-            get_device_compute_capability(0) >= 100
-            and self.dropout_prob == 0.1
-            and self.attn_bias_type is not AttnBiasType.NO_BIAS
-        ):
-            pytest.skip(
-                "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
-            )
+        if get_device_compute_capability(0) >= 100 and self.is_training:
+            if FusedAttnHelper.is_non_deterministic_allowed() and (
+                (self.dropout_prob != 0.0 and self.attn_bias_type != AttnBiasType.NO_BIAS)
+                or get_cudnn_version() < 90700
+            ):
+                pytest.skip(
+                    "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with"
+                    " dropout"
+                )
+            if not FusedAttnHelper.is_non_deterministic_allowed() and (
+                self.dropout_prob != 0.0
+                or self.attn_bias_type != AttnBiasType.NO_BIAS
+                or get_cudnn_version() < 91801
+            ):
+                pytest.skip(
+                    "For sm100+, deterministic bprop (cuDNN 9.18.1+) does not support bias or"
+                    " dropout"
+                )
# Test the MLA case where head dims for qk differ from head dims for v, only if the tensors
# are provided in BSHD_BSHD_BSHD or THD_THD_THD formats
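The skip logic above encodes the sm100+ (Blackwell) training-bprop support matrix that the backend selection in `fused_attn.cpp` (further down) enforces. A hedged restatement in plain Python; the function name and signature are illustrative, not TE API:

```python
def blackwell_bprop_supported(deterministic: bool, cudnn_version: int,
                              dropout: float, has_bias: bool) -> bool:
    """Illustrative sm100+ bprop support matrix; cudnn_version is the integer runtime version."""
    if deterministic:
        # Deterministic bprop: cuDNN 9.18.1+ and neither bias nor dropout.
        return cudnn_version >= 91801 and dropout == 0.0 and not has_bias
    # Non-deterministic bprop: cuDNN 9.7+ and bias not combined with dropout.
    return cudnn_version >= 90700 and (dropout == 0.0 or not has_bias)
```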
......@@ -1269,6 +1283,7 @@ class FusedAttnRunner:
pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
],
)
@pytest.mark.skipif(_deterministic, reason="Test non-determinism only")
class TestFusedAttn:
"""
Fused attention tester
......@@ -1392,3 +1407,182 @@ class TestFusedAttn:
seq_desc_format,
)
runner.test_backward()
@pytest.mark.parametrize(
"attn_mask_type",
[
pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
pytest.param(
AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT"
),
],
)
@pytest.mark.parametrize(
"softmax_type",
[
pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"),
],
)
@pytest.mark.parametrize(
"b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype, qkv_layout",
[
# large data size + fp16 + cross attn + gqa + diff hidden v dim + qkv separate
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.BSHD_BSHD_BSHD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-SEPARATE",
),
pytest.param(
2,
1024,
2048,
12,
6,
128,
64,
jnp.bfloat16,
QKVLayout.THD_THD_THD,
id="2-1024-2048-12-6-128-64-BF16-CROSS-GQA-RAGGED_SEPARATE",
),
],
)
@pytest.mark.parametrize(
"dropout_prob",
[
pytest.param(0.0, id="DROP_0.0"),
],
)
@pytest.mark.parametrize(
"swa",
[
pytest.param(False, id="NO_SWA"),
],
)
@pytest.mark.parametrize(
"seq_desc_format",
[
pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
],
)
@pytest.mark.skipif(not _deterministic, reason="Test determinism only")
class TestFusedAttnWithDeterminism:
"""
Fused attention tester with determinism
"""
@staticmethod
@pytest.mark.parametrize(
"is_training",
[
pytest.param(True, id="TRAINING"),
],
)
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def _test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test forward with parameterized configs
This test is not intended to run automatically during CI as it is time-consuming
It is kept for development and debugging
"""
TestFusedAttn._test_forward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
is_training,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
@staticmethod
@pytest.mark.parametrize(
"attn_bias_type, bias_shape",
[
pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
],
)
def test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
):
"""
Test backward with parameterized configs
"""
TestFusedAttn.test_backward(
b,
s_q,
s_kv,
h_q,
h_kv,
d_qk,
d_v,
attn_bias_type,
attn_mask_type,
softmax_type,
dropout_prob,
dtype,
qkv_layout,
bias_shape,
swa,
seq_desc_format,
)
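Because `_deterministic` is evaluated at module import, the deterministic test class is selected by exporting the environment variable before pytest collects this file, exactly as the QA script change above does. A hedged local-reproduction sketch:

```python
# Illustrative: run only the deterministic JAX fused-attention tests locally.
import os

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # must precede test collection

import pytest

raise SystemExit(
    pytest.main(["tests/jax/test_fused_attn.py", "-k", "TestFusedAttnWithDeterminism", "-v"])
)
```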
......@@ -72,6 +72,14 @@ if fp8_available and (device_compute_capability < (9, 0) or device_compute_capab
f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}"
)
# Get determinism
_deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
or torch.are_deterministic_algorithms_enabled()
)
# Reset RNG seed and states
seed = 1234
reset_rng_states()
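Note that `_deterministic` above also turns on when PyTorch's global deterministic mode is enabled, so either switch puts the suite into deterministic mode. A small illustrative check:

```python
# Illustrative: PyTorch's global switch also flips _deterministic above.
import torch

torch.use_deterministic_algorithms(True)
assert torch.are_deterministic_algorithms_enabled()
```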
......@@ -160,6 +168,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
......@@ -170,6 +179,7 @@ def test_dot_product_attention(
qkv_layout=qkv_layout,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -886,11 +896,14 @@ def _run_dot_product_attention(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create seqlens
......@@ -1292,6 +1305,7 @@ def test_transformer_layer(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
......@@ -1305,6 +1319,7 @@ def test_transformer_layer(
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
......@@ -1432,10 +1447,13 @@ def _run_transformer_layer(
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
# Create input tensor
......@@ -1629,6 +1647,7 @@ def test_dpa_fp8_extra_state(model, dtype):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="sb3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported and not flash_attn_supported:
......@@ -1819,6 +1838,7 @@ def test_mha_fp8_vs_f16(
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported_fp8 < 1:
......@@ -1830,6 +1850,7 @@ def test_mha_fp8_vs_f16(
qkv_dtype=dtype,
qkv_layout=qkv_format.replace("hd", "h3d"),
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported_f16, _ = available_backends
if not fused_attn_supported_f16:
......@@ -1838,6 +1859,7 @@ def test_mha_fp8_vs_f16(
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
......@@ -1847,6 +1869,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_fp8:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
......@@ -1856,6 +1879,7 @@ def test_mha_fp8_vs_f16(
if fused_attn_supported_f16:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
......@@ -2068,6 +2092,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
fp8=True,
fp8_meta=fp8_meta,
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if flash_attn_supported + fused_attn_supported < 1:
......@@ -2078,6 +2103,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
qkv_dtype=dtype,
qkv_layout=qkv_layout,
is_training=is_training,
deterministic=_deterministic,
)
_, fused_attn_supported, _ = available_backends
if not fused_attn_supported:
......@@ -2088,6 +2114,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if flash_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "1"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)")
flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2097,6 +2124,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
if unfused_attn_supported:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)")
unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2105,6 +2133,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
_attention_backends["backend_selection_requires_update"] = True
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
......@@ -2113,6 +2142,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scal
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if config.dropout_p == 0.0:
# test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)")
......@@ -2367,13 +2397,16 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
qkv_dtype=torch.float8_e4m3fn,
qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
is_training=is_training,
deterministic=_deterministic,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not (fused_attn_backends and unfused_attn_supported):
pytest.skip("Not enough backends to run this test with.")
fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
-    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
+    unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
+        dtype, config, "UnfusedDotProductAttention"
+    )
atol = 5e-1
rtol = 5e-1
......@@ -2406,10 +2439,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
reset_rng_states()
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = 0.0001 * torch.randint(
......@@ -2460,10 +2496,13 @@ def _run_ref_mha_f16(dtype, config, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
os.environ["NVTE_UNFUSED_ATTN"] = "0"
if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
if backend == "UnfusedDotProductAttention":
os.environ["NVTE_UNFUSED_ATTN"] = "1"
_attention_backends["backend_selection_requires_update"] = True
inp = torch.load("qkv.pt").to(device="cuda")
......
......@@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -440,7 +440,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.13.1+: vanilla, off-by-one, learnable
(cudnn_runtime_version >= 91301 ||
(cudnn_runtime_version < 91301 &&
-       softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
+       softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) &&
+      // determinism on Blackwell
+      // pre-9.18.1: fwd: deterministic; bwd: non-deterministic
+      // 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic
+      (sm_arch_ < 100 ||
+       (sm_arch_ >= 100 && (!is_training ||
+                            (is_training && !deterministic &&
+                             (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) ||
+                            (is_training && deterministic && cudnn_runtime_version >= 91801 &&
+                             dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
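The integer thresholds in the predicate above follow cuDNN's runtime version encoding for cuDNN 9+ (assumed here: major*10000 + minor*100 + patch), so 91801 corresponds to 9.18.1 and 90700 to 9.7.0:

```python
def encode_cudnn_version(major: int, minor: int, patch: int) -> int:
    # Assumed cuDNN 9+ runtime version encoding: major*10000 + minor*100 + patch.
    return major * 10000 + minor * 100 + patch

assert encode_cudnn_version(9, 18, 1) == 91801  # deterministic-bprop threshold
assert encode_cudnn_version(9, 7, 0) == 90700   # non-deterministic-bprop threshold
```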
......@@ -553,7 +562,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
-      cuda_graph);
+      cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -595,7 +604,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -669,7 +679,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
-      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph,
+      deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -855,7 +866,7 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -897,7 +908,8 @@ void nvte_fused_attn_fwd_kvpacked(
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -982,10 +994,10 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
-  NVTE_Fused_Attn_Backend fused_attention_backend =
-      nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
-                                  softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
-                                  d, window_size_left, window_size_right, false, cuda_graph);
+  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
+      h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false,
+      cuda_graph, deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -1166,7 +1178,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -1189,7 +1201,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -1262,7 +1275,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
-      cuda_graph);
+      cuda_graph, deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......
......@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] deterministic Whether determinism is required or not.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
/*! \brief Compute dot product attention with packed QKV input.
*
......
......@@ -144,6 +144,7 @@ class FusedAttnHelper:
self.head_dim_v,
self.window_size[0],
self.window_size[1],
not self.is_non_deterministic_allowed(),
)
@staticmethod
......@@ -3563,13 +3564,21 @@ def fused_attn_bwd(
softmax_offset, (None, HEAD_AXES, None, None)
)
-    # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on
-    # sm100+
     compute_capabilities = get_all_device_compute_capability()
-    if any(x >= 100 for x in compute_capabilities):
-        assert not (
-            attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0
-        ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported"
+    if any(x >= 100 for x in compute_capabilities) and is_training:
+        assert (
+            FusedAttnHelper.is_non_deterministic_allowed()
+            and get_cudnn_version() >= (9, 7, 0)
+            and (attn_bias_type == AttnBiasType.NO_BIAS or dropout_probability == 0.0)
+        ) or (
+            not FusedAttnHelper.is_non_deterministic_allowed()
+            and get_cudnn_version() >= (9, 18, 1)
+            and attn_bias_type == AttnBiasType.NO_BIAS
+            and dropout_probability == 0.0
+        ), (
+            "For sm100+, non-deterministic bprop (cuDNN 9.7+) does not support bias with dropout,"
+            " and deterministic bprop (cuDNN 9.18.1+) does not support bias or dropout"
+        )
fused_config = _FusedAttnConfig(
attn_bias_type=attn_bias_type,
......
......@@ -113,7 +113,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-    int64_t window_size_right);
+    int64_t window_size_right, bool deterministic);
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
......
......@@ -16,12 +16,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen,
size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left,
-    int64_t window_size_right) {
+    int64_t window_size_right, bool deterministic) {
auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
return backend;
}
......@@ -266,7 +266,7 @@ static void FusedAttnForwardImpl(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */
......@@ -522,7 +522,7 @@ static void FusedAttnBackwardImpl(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right,
-      false, false);
+      false, false, deterministic);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias, softmax_offset);
......
......@@ -994,6 +994,7 @@ def get_attention_backend(
window_size[1],
return_max_logit,
cuda_graph,
deterministic,
)
if fused_attention_backend == FusedAttnBackend["No_Backend"]:
logger.debug("Disabling FusedAttention as no backend supports the provided input")
......@@ -1064,10 +1065,6 @@ def get_attention_backend(
logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
use_fused_attention = False
fused_attention_backend = None
-    if is_training and device_compute_capability >= (10, 0):
-        logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
-        use_fused_attention = False
-        fused_attention_backend = None
# use_flash_attention may have been set above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2
......
......@@ -81,7 +81,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
......
......@@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups,
max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right,
-      return_max_logit, cuda_graph);
+      return_max_logit, cuda_graph, deterministic);
return fused_attention_backend;
}
......