Unverified Commit 8afcd0f6 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Bugfix] Fix broken kernel test due to missing rename for v1 Triton backend (#15282)


Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
parent 91ca929d
...@@ -49,7 +49,7 @@ def test_env( ...@@ -49,7 +49,7 @@ def test_env(
RocmPlatform()): RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16, backend = get_attn_backend(16, torch.float16, torch.float16,
16, False) 16, False)
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH" EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED assert backend.get_name() == EXPECTED
elif device == "openvino": elif device == "openvino":
with patch("vllm.attention.selector.current_platform", with patch("vllm.attention.selector.current_platform",
......
...@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): ...@@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
# Test standard ROCm attention # Test standard ROCm attention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False) backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
assert (backend.get_name() == "ROCM_FLASH" assert (backend.get_name() == "ROCM_FLASH"
or backend.get_name() == "ROCM_ATTN_VLLM_V1") or backend.get_name() == "TRITON_ATTN_VLLM_V1")
# mla test for deepseek related # mla test for deepseek related
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment