Unverified Commit 3999442f authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[CI/Build][AMD] Add check for flash_att_varlen_func to test_tree_attention.py (#29252)


Signed-off-by: default avatarRandall Smith <ransmith@amd.com>
Co-authored-by: default avatarRandall Smith <ransmith@amd.com>
parent 71362ffa
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import math import math
import pytest
import torch import torch
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
...@@ -11,9 +12,16 @@ from tests.v1.attention.utils import ( ...@@ -11,9 +12,16 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata
if not is_flash_attn_varlen_func_available():
pytest.skip(
"This test requires flash_attn_varlen_func, but it's not available.",
allow_module_level=True,
)
class MockAttentionLayer(torch.nn.Module): class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment