Unverified commit 071b9508, authored by cyanguwa and committed by GitHub
Browse files

[PyTorch] Skip context parallel tests on architectures below sm80 (#799)



Restrict context parallel tests to sm80+, as the fused/flash attention backends require sm80+.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 7f1d604f
......@@ -10,6 +10,7 @@ from test_fused_attn import (
_is_flash_attention_2_available,
_cudnn_version,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
......@@ -29,6 +30,7 @@ def get_bash_arguments(**kwargs):
return args
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
......@@ -56,6 +58,7 @@ model_configs_fused_attn = {
}
@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
......
Markdown is supported
Attach a file by dragging &amp; dropping it, or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.