Unverified commit b12f4a98, authored by rasmith, committed by GitHub
Browse files

[CI/Build][AMD] Use ROCM_ATTN instead of FLASH_ATTN test for...


[CI/Build][AMD] Use ROCM_ATTN instead of FLASH_ATTN test for test_register_kv_caches for ROCm and update test for TRITON_ATTN (#29985)
Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
parent 40a046cd
...@@ -41,6 +41,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import ( ...@@ -41,6 +41,7 @@ from vllm.distributed.kv_transfer.kv_transfer_state import (
has_kv_transfer_group, has_kv_transfer_group,
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform from vllm.platforms.interface import Platform
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
...@@ -1111,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int): ...@@ -1111,7 +1112,26 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
llm.llm_engine.engine_core.shutdown() llm.llm_engine.engine_core.shutdown()
@pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "TRITON_ATTN"]) @pytest.mark.parametrize(
"attn_backend",
[
pytest.param(
"FLASH_ATTN",
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASH_ATTN is not supported on ROCm",
),
),
pytest.param(
"ROCM_ATTN",
marks=pytest.mark.skipif(
not current_platform.is_rocm(),
reason="Attention backend ROCM_ATTN is only supported on ROCm",
),
),
"TRITON_ATTN",
],
)
def test_register_kv_caches(dist_init, attn_backend, monkeypatch): def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
""" """
Test that register_kv_caches() properly calls nixl_wrapper methods with Test that register_kv_caches() properly calls nixl_wrapper methods with
...@@ -1133,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): ...@@ -1133,6 +1153,10 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
backend_cls = FlashAttentionBackend backend_cls = FlashAttentionBackend
elif attn_backend == "ROCM_ATTN":
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
backend_cls = RocmAttentionBackend
else: # TRITON_ATTN else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
...@@ -1151,6 +1175,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): ...@@ -1151,6 +1175,7 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
} }
# Store tensor info for validation # Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape( test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
) )
...@@ -1175,17 +1200,18 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch): ...@@ -1175,17 +1200,18 @@ def test_register_kv_caches(dist_init, attn_backend, monkeypatch):
] ]
expected_num_entries = 4 expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with ( with (
patch( patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper" patch(f"{nixl_module}.threading.Event"),
) as mock_nixl_wrapper, patch(f"{nixl_module}.threading.Thread") as mock_thread,
patch( patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event" ):
), # Ensure get_attn_backend returns the correct value due to
patch( # _cached_get_attn_backend returning the backend from previous
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread" # test run if not mocking.
) as mock_thread, mock_get_attn_backend.return_value = backend_cls
): # noqa: E501
# Create connector # Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker( connector.connector_worker = FakeNixlConnectorWorker(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment