"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "ce2e8bd12edfe10647bec8f54fedc394d6287b58"
Commit a5c79876 authored by Charlene Yang, committed by GitHub

[PyTorch] Disable determinism for sm100 (#2130)



* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
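
The new backend gate is compact enough to summarize outside the diff. Below is a minimal, illustrative sketch (the standalone helper name and signature are invented for illustration, not TransformerEngine's API); the thresholds mirror the condition added to get_attention_backend() in the diff that follows: deterministic FusedAttention is turned off for training on sm100+ (Blackwell) unless cuDNN is newer than 9.14.0.

    # Illustrative sketch only; not the actual TransformerEngine code path.
    # The thresholds match the check added in get_attention_backend() below.
    def fused_attention_can_be_deterministic(
        is_training: bool,
        device_compute_capability: tuple,  # e.g. (10, 0) for sm100 / Blackwell
        cudnn_version: tuple,              # e.g. (9, 14, 0)
    ) -> bool:
        """Return False when deterministic FusedAttention has to be disabled."""
        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
            return False
        return True

    # Example: training on sm100 with cuDNN 9.13 -> fused attention cannot be deterministic.
    assert fused_attention_can_be_deterministic(True, (10, 0), (9, 13, 0)) is False
    assert fused_attention_can_be_deterministic(True, (9, 0), (9, 13, 0)) is True
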
@@ -122,13 +122,18 @@ if fp8_available:
 def is_fused_attn_available(
-    config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True
+    config: ModelConfig,
+    dtype: torch.dtype,
+    qkv_layout="bshd_bshd_bshd",
+    is_training=True,
+    deterministic=False,
 ):
     _, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         is_training=is_training,
+        deterministic=deterministic,
     )
     return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
@@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
 @pytest.mark.parametrize("model", ["126m"])
 def test_gpt_checkpointing(dtype, bs, model):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype):
+    if not is_fused_attn_available(config, dtype, deterministic=True):
         pytest.skip("No attention backend available.")
     outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
     outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
@@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
 @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
 def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_gpt = TransformerLayer(
@@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
 @pytest.mark.parametrize("mask_type", mask_types)
 def test_mha_accuracy(dtype, bs, model, mask_type):
     config = model_configs[model]
-    if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False):
+    if not is_fused_attn_available(
+        config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
+    ):
         pytest.skip("No attention backend available.")
     te_mha = MultiheadAttention(
@@ -266,8 +266,8 @@ def get_available_attention_backends(
     )
     (
         use_flash_attention,
-        use_fused_attention,
         flash_attention_backend,
+        use_fused_attention,
         fused_attention_backend,
         use_unfused_attention,
         available_backends,
@@ -822,7 +822,7 @@ def get_attention_backend(
     # flash-attn >=2.4.1         | yes
     # FusedAttention             |
     #     sub-backend 0          | yes
-    #     sub-backend 1          | workspace optimization path and sm90+: yes;
+    #     sub-backend 1          | workspace optimization path and sm90: yes;
     #                            | otherwise: no
     #     sub-backend 2          | no
     # UnfusedDotProductAttention | yes
@@ -838,8 +838,9 @@ def get_attention_backend(
             use_flash_attention_2 = False
     if use_fused_attention and deterministic:
         if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with FP8")
             use_fused_attention = False
+            fused_attention_backend = None
         if (
             fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
             and is_training
@@ -849,8 +850,13 @@ def get_attention_backend(
                 or cudnn_version < (8, 9, 5)
             )
         ):
-            logger.debug("Disabling FusedAttention for determinism reasons")
+            logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
             use_fused_attention = False
+            fused_attention_backend = None
+        if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
+            logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
+            use_fused_attention = False
+            fused_attention_backend = None
         # use_flash_attention may have been set above
         use_flash_attention_2 = use_flash_attention and use_flash_attention_2
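
To see whether a given machine falls under this new gate, the two inputs can be read with standard PyTorch calls. This is a rough, assumption-laden sketch: it decodes torch.backends.cudnn.version() with the cuDNN 9.x scheme (major*10000 + minor*100 + patch); cuDNN 8.x builds encode the version differently, so verify the decoding against your install.

    import torch

    # Compute capability of the current GPU as a (major, minor) tuple, e.g. (10, 0) for sm100.
    cc = torch.cuda.get_device_capability()

    # cuDNN version as reported by PyTorch, e.g. 91400 for 9.14.0.
    # Decoding below assumes the cuDNN 9.x encoding (major*10000 + minor*100 + patch).
    v = torch.backends.cudnn.version()
    cudnn = (v // 10000, (v // 100) % 100, v % 100)

    # Mirrors the condition added in this commit: deterministic FusedAttention is
    # unavailable for training on sm100+ until cuDNN is newer than 9.14.0.
    deterministic_fused_attn_ok = not (cc >= (10, 0) and cudnn <= (9, 14, 0))
    print(f"compute capability={cc}, cuDNN={cudnn}, deterministic fused attention={deterministic_fused_attn_ok}")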