Unverified Commit 0650e517 authored by Yineng Zhang, committed by GitHub

fix: only enable flash_attn test on sm80 sm90 (#7289)

parent fc554105
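The skip markers added below all gate on is_fa3_supported, which this diff imports but never shows. A minimal sketch of such a gate, assuming it only needs to check for compute capability 9.0 (sm90, Hopper) or 8.0 (sm80, Ampere) as the skip reason states; the real helper in sgl-kernel may check more (CUDA availability, head dims, dtype support):

import torch

def is_fa3_supported(device=None) -> bool:
    # Hypothetical sketch, not the sgl-kernel implementation: return True
    # only on the architectures the skip reason names (sm90 or sm80).
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability(device)
    return (major, minor) in ((9, 0), (8, 0))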
@@ -13,7 +13,7 @@ apply_rotary_emb = None
 
 def is_hopper():
     # Only Hopper supports different V headdim
-    return torch.cuda.get_device_properties(0).major >= 9
+    return torch.cuda.get_device_properties(0).major == 9
 
 
 def is_fa3_supported(device=None) -> bool:
@@ -451,7 +451,7 @@ def generate_qkv(
 
 @pytest.mark.skipif(
     not is_fa3_supported(),
-    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
 )
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
@@ -1009,6 +1009,10 @@ def _generate_block_kvcache(
     return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
 
 
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
     "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
...
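Why the is_hopper change swaps >= for ==: compute-capability majors keep increasing past Hopper (8 = Ampere, 9 = Hopper, 10 = Blackwell), so >= 9 would also claim the Hopper-only V-headdim path on newer, non-Hopper parts. A toy illustration (the architecture-to-major mapping is general knowledge, not from this diff):

for major, arch in [(8, "Ampere"), (9, "Hopper"), (10, "Blackwell")]:
    print(f"{arch}: major >= 9 -> {major >= 9}, major == 9 -> {major == 9}")
# Blackwell: major >= 9 -> True  (wrongly treated as Hopper)
# Blackwell: major == 9 -> False (correct)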
@@ -8,9 +8,8 @@ from sgl_kernel.sparse_flash_attn import (
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
     sparse_attn_func,
-    sparse_attn_varlen_func,
 )
-from test_flash_attention import construct_local_mask
+from test_flash_attention import construct_local_mask, is_fa3_supported
 
 
 def ref_attn(
@@ -172,6 +171,10 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize(
     "seq_lens",
@@ -257,6 +260,10 @@ def test_sparse_attention(
 
 # sparse attention utils
 # origin
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes(causal):
     # Prepare small, hand-checkable inputs
@@ -311,6 +318,10 @@ def test_convert_vertical_slash_indexes(causal):
 
 
 # mergehead
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes_mergehead(causal):
     # Prepare small, hand-checkable inputs for mergehead version
...
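With these markers in place, the suite degrades to skips rather than failures on unsupported GPUs. A typical run (the file name test_flash_attention.py is inferred from the import above; the sparse test's file name is not shown in this diff):

python -m pytest -v test_flash_attention.py
# On a GPU that is neither sm80 nor sm90, the gated tests should report:
# SKIPPED ... flash_attn at sgl-kernel is only supported on sm90 or sm80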