"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "c9d7f3f272c654cd46a571beb5fe97a7dd7602cf"
Unverified commit a5c79876 authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Disable determinism for sm100 (#2130)



* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c7761419
...@@ -122,13 +122,18 @@ if fp8_available: ...@@ -122,13 +122,18 @@ if fp8_available:
def is_fused_attn_available( def is_fused_attn_available(
config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True config: ModelConfig,
dtype: torch.dtype,
qkv_layout="bshd_bshd_bshd",
is_training=True,
deterministic=False,
): ):
_, _, fused_attn_backends = get_available_attention_backends( _, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
is_training=is_training, is_training=is_training,
deterministic=deterministic,
) )
return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends
...@@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= ...@@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path=
@pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("model", ["126m"])
def test_gpt_checkpointing(dtype, bs, model): def test_gpt_checkpointing(dtype, bs, model):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(config, dtype): if not is_fused_attn_available(config, dtype, deterministic=True):
pytest.skip("No attention backend available.") pytest.skip("No attention backend available.")
outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False)
outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True)
...@@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): ...@@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.") pytest.skip("No attention backend available.")
te_gpt = TransformerLayer( te_gpt = TransformerLayer(
...@@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): ...@@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
@pytest.mark.parametrize("mask_type", mask_types) @pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type): def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model] config = model_configs[model]
if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): if not is_fused_attn_available(
config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True
):
pytest.skip("No attention backend available.") pytest.skip("No attention backend available.")
te_mha = MultiheadAttention( te_mha = MultiheadAttention(
......
...@@ -266,8 +266,8 @@ def get_available_attention_backends( ...@@ -266,8 +266,8 @@ def get_available_attention_backends(
) )
( (
use_flash_attention, use_flash_attention,
use_fused_attention,
flash_attention_backend, flash_attention_backend,
use_fused_attention,
fused_attention_backend, fused_attention_backend,
use_unfused_attention, use_unfused_attention,
available_backends, available_backends,
......
...@@ -822,7 +822,7 @@ def get_attention_backend( ...@@ -822,7 +822,7 @@ def get_attention_backend(
# flash-attn >=2.4.1 | yes # flash-attn >=2.4.1 | yes
# FusedAttention | # FusedAttention |
# sub-backend 0 | yes # sub-backend 0 | yes
# sub-backend 1 | workspace optimization path and sm90+: yes; # sub-backend 1 | workspace optimization path and sm90: yes;
# | otherwise: no # | otherwise: no
# sub-backend 2 | no # sub-backend 2 | no
# UnfusedDotProductAttention | yes # UnfusedDotProductAttention | yes
...@@ -838,8 +838,9 @@ def get_attention_backend( ...@@ -838,8 +838,9 @@ def get_attention_backend(
use_flash_attention_2 = False use_flash_attention_2 = False
if use_fused_attention and deterministic: if use_fused_attention and deterministic:
if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training:
logger.debug("Disabling FusedAttention for determinism reasons") logger.debug("Disabling FusedAttention for determinism reasons with FP8")
use_fused_attention = False use_fused_attention = False
fused_attention_backend = None
if ( if (
fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
and is_training and is_training
...@@ -849,8 +850,13 @@ def get_attention_backend( ...@@ -849,8 +850,13 @@ def get_attention_backend(
or cudnn_version < (8, 9, 5) or cudnn_version < (8, 9, 5)
) )
): ):
logger.debug("Disabling FusedAttention for determinism reasons") logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias")
use_fused_attention = False
fused_attention_backend = None
if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0):
logger.debug("Disabling FusedAttention for determinism reasons on Blackwell")
use_fused_attention = False use_fused_attention = False
fused_attention_backend = None
# use_flash_attention may have been set above # use_flash_attention may have been set above
use_flash_attention_2 = use_flash_attention and use_flash_attention_2 use_flash_attention_2 = use_flash_attention and use_flash_attention_2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment