Unverified Commit 46bc37d0 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Disable THD tests on architectures lower than sm90 (#973)



* disable CP-THD tests for fused attn on <sm90
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 38524f71
...@@ -66,6 +66,8 @@ model_configs_fused_attn = { ...@@ -66,6 +66,8 @@ model_configs_fused_attn = {
@pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format): def test_cp_with_fused_attention(dtype, model, qkv_format):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+.")
subprocess.run( subprocess.run(
get_bash_arguments( get_bash_arguments(
dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention" dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment