Unverified commit d63b9696, authored by Alexei-V-Ivanov-AMD, committed by GitHub
Browse files

[CI/ROCm] Fixing "V1 Test attention (H100)" test group. (#31187)


Signed-off-by: DCCS-4560 <alivanov@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com>
Signed-off-by: <>
Co-authored-by: DCCS-4560 <alivanov@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com>
Co-authored-by: root <root@chi-mi325x-pod1-108.ord.vultr.cpe.ice.amd.com>
parent 56f51625
...@@ -557,9 +557,21 @@ def test_causal_backend_correctness( ...@@ -557,9 +557,21 @@ def test_causal_backend_correctness(
if is_torch_equal_or_newer("2.9.0.dev0") if is_torch_equal_or_newer("2.9.0.dev0")
else [] else []
) )
if current_platform.is_rocm():
SMALL_BLOCK_BACKENDS = [
x
for x in BACKENDS_TO_TEST
if (
x not in LARGE_BLOCK_BACKENDS
and x is not AttentionBackendEnum.FLASH_ATTN
)
]
else:
SMALL_BLOCK_BACKENDS = [ SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
] ]
_test_backend_correctness( _test_backend_correctness(
batch_spec, batch_spec,
model, model,
...@@ -580,12 +592,20 @@ def test_causal_backend_correctness( ...@@ -580,12 +592,20 @@ def test_causal_backend_correctness(
) )
SLIDING_WINDOW_BACKENDS_TO_TEST = [ if current_platform.is_rocm():
# FLASH_ATTN is not supported on ROCm
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
else:
SLIDING_WINDOW_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION, AttentionBackendEnum.FLEX_ATTENTION,
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
"FLEX_ATTENTION_SLOW", "FLEX_ATTENTION_SLOW",
] ]
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
import torch import torch
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.platforms import current_platform from vllm.platforms import current_platform
# ROCm-specific attention backend selection tests # ROCm-specific attention backend selection tests
...@@ -144,8 +145,7 @@ def test_standard_attention_backend_selection( ...@@ -144,8 +145,7 @@ def test_standard_attention_backend_selection(
# Get the backend class path # Get the backend class path
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
backend_path = RocmPlatform.get_attn_backend_cls( attn_selector_config = AttentionSelectorConfig(
selected_backend=backend_enum,
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
...@@ -154,6 +154,11 @@ def test_standard_attention_backend_selection( ...@@ -154,6 +154,11 @@ def test_standard_attention_backend_selection(
has_sink=False, has_sink=False,
use_sparse=False, use_sparse=False,
) )
backend_path = RocmPlatform.get_attn_backend_cls(
selected_backend=backend_enum, attn_selector_config=attn_selector_config
)
assert backend_path == expected_backend_path assert backend_path == expected_backend_path
...@@ -267,8 +272,16 @@ def test_mla_backend_selection( ...@@ -267,8 +272,16 @@ def test_mla_backend_selection(
if should_raise: if should_raise:
with pytest.raises(ValueError): with pytest.raises(ValueError):
RocmPlatform.get_attn_backend_cls( attn_selector_config = AttentionSelectorConfig(
selected_backend=backend_enum, head_size=128,
dtype=torch.float16,
kv_cache_dtype="auto",
block_size=block_size,
use_mla=True,
has_sink=False,
use_sparse=False,
)
attn_selector_config = AttentionSelectorConfig(
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
...@@ -277,9 +290,13 @@ def test_mla_backend_selection( ...@@ -277,9 +290,13 @@ def test_mla_backend_selection(
has_sink=False, has_sink=False,
use_sparse=False, use_sparse=False,
) )
else:
backend_path = RocmPlatform.get_attn_backend_cls( backend_path = RocmPlatform.get_attn_backend_cls(
selected_backend=backend_enum, selected_backend=backend_enum,
attn_selector_config=attn_selector_config,
)
else:
attn_selector_config = AttentionSelectorConfig(
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
...@@ -288,6 +305,11 @@ def test_mla_backend_selection( ...@@ -288,6 +305,11 @@ def test_mla_backend_selection(
has_sink=False, has_sink=False,
use_sparse=False, use_sparse=False,
) )
backend_path = RocmPlatform.get_attn_backend_cls(
selected_backend=backend_enum, attn_selector_config=attn_selector_config
)
assert backend_path == expected_backend_path assert backend_path == expected_backend_path
...@@ -303,8 +325,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): ...@@ -303,8 +325,7 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
match="only supported on gfx9", match="only supported on gfx9",
), ),
): ):
RocmPlatform.get_attn_backend_cls( attn_selector_config = AttentionSelectorConfig(
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
...@@ -314,6 +335,11 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): ...@@ -314,6 +335,11 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config):
use_sparse=False, use_sparse=False,
) )
RocmPlatform.get_attn_backend_cls(
selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
attn_selector_config=attn_selector_config,
)
def test_sparse_not_supported(mock_vllm_config): def test_sparse_not_supported(mock_vllm_config):
"""Test that sparse attention is not supported on ROCm.""" """Test that sparse attention is not supported on ROCm."""
...@@ -322,8 +348,7 @@ def test_sparse_not_supported(mock_vllm_config): ...@@ -322,8 +348,7 @@ def test_sparse_not_supported(mock_vllm_config):
with pytest.raises( with pytest.raises(
AssertionError, match="Sparse MLA backend on ROCm only supports block size 1" AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
): ):
RocmPlatform.get_attn_backend_cls( attn_selector_config = AttentionSelectorConfig(
selected_backend=None,
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="auto", kv_cache_dtype="auto",
...@@ -332,3 +357,7 @@ def test_sparse_not_supported(mock_vllm_config): ...@@ -332,3 +357,7 @@ def test_sparse_not_supported(mock_vllm_config):
has_sink=False, has_sink=False,
use_sparse=True, use_sparse=True,
) )
RocmPlatform.get_attn_backend_cls(
selected_backend=None, attn_selector_config=attn_selector_config
)
...@@ -24,6 +24,7 @@ from vllm import _custom_ops as ops ...@@ -24,6 +24,7 @@ from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla from vllm.attention.ops import flashmla
from vllm.config import set_current_vllm_config from vllm.config import set_current_vllm_config
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseBackend,
...@@ -125,6 +126,9 @@ def _quantize_dequantize_fp8_ds_mla( ...@@ -125,6 +126,9 @@ def _quantize_dequantize_fp8_ds_mla(
def test_sparse_backend_decode_correctness( def test_sparse_backend_decode_correctness(
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
): ):
if current_platform.is_rocm():
pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")
if not torch.cuda.is_available(): if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test") pytest.skip("CUDA is required for sparse MLA decode test")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment