Unverified commit 3d84ef90, authored by rasmith and committed by GitHub
Browse files

[CI/Build][AMD] Skip if flash_attn_varlen_func not available in test_aiter_flash_attn.py (#29043)


Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
parent 4d01b642
@@ -6,6 +6,7 @@ import pytest
 import torch
 
 import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
+from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
 from vllm.platforms import current_platform
 
 NUM_HEADS = [(4, 4), (8, 2)]
@@ -100,6 +101,8 @@ def test_varlen_with_paged_kv(
     num_blocks: int,
     q_dtype: torch.dtype | None,
 ) -> None:
+    if not is_flash_attn_varlen_func_available():
+        pytest.skip("flash_attn_varlen_func required to run this test.")
     torch.set_default_device("cuda")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment