Unverified commit 3ffa5200, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][CI] Guard CudaPlatform/RocmPlatform imports to fix test collection on...


[ROCm][CI] Guard CudaPlatform/RocmPlatform imports to fix test collection on cross-platform builds (#37617)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent 87bd9189
...@@ -14,8 +14,19 @@ from vllm.config import ( ...@@ -14,8 +14,19 @@ from vllm.config import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform # CudaPlatform and RocmPlatform import their respective compiled C extensions
# at module level, raising ModuleNotFoundError on incompatible builds.
try:
from vllm.platforms.cuda import CudaPlatform
except (ImportError, ModuleNotFoundError):
CudaPlatform = None
try:
from vllm.platforms.rocm import RocmPlatform
except (ImportError, ModuleNotFoundError):
RocmPlatform = None
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
...@@ -101,6 +112,8 @@ def test_backend_selection( ...@@ -101,6 +112,8 @@ def test_backend_selection(
assert backend.get_name() == "CPU_ATTN" assert backend.get_name() == "CPU_ATTN"
elif device == "hip": elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
with patch("vllm.platforms.current_platform", RocmPlatform()): with patch("vllm.platforms.current_platform", RocmPlatform()):
if use_mla: if use_mla:
# ROCm MLA backend logic: # ROCm MLA backend logic:
...@@ -126,6 +139,8 @@ def test_backend_selection( ...@@ -126,6 +139,8 @@ def test_backend_selection(
assert backend.get_name() == expected assert backend.get_name() == expected
elif device == "cuda": elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
if use_mla: if use_mla:
...@@ -214,7 +229,7 @@ def test_backend_selection( ...@@ -214,7 +229,7 @@ def test_backend_selection(
assert backend.get_name() == expected assert backend.get_name() == expected
@pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("device", ["cpu", "cuda", "hip"])
def test_fp32_fallback(device: str): def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32.""" """Test attention backend selection with fp32."""
# Use default config (no backend specified) # Use default config (no backend specified)
...@@ -227,10 +242,25 @@ def test_fp32_fallback(device: str): ...@@ -227,10 +242,25 @@ def test_fp32_fallback(device: str):
assert backend.get_name() == "CPU_ATTN" assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None) backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "FLEX_ATTENTION" assert backend.get_name() == "FLEX_ATTENTION"
elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
# ROCm backends do not support head_size=16 (minimum is 32).
# No known HuggingFace transformer model uses head_size=16.
# Revisit if a real model with this head size is identified
# and accuracy-tested.
with (
patch("vllm.platforms.current_platform", RocmPlatform()),
pytest.raises(ValueError, match="No valid attention backend"),
):
get_attn_backend(16, torch.float32, None)
def test_flash_attn(monkeypatch: pytest.MonkeyPatch): def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation.""" """Test FlashAttn validation."""
...@@ -367,6 +397,8 @@ def test_per_head_quant_scales_backend_selection( ...@@ -367,6 +397,8 @@ def test_per_head_quant_scales_backend_selection(
attention_config=attention_config, cache_config=cache_config attention_config=attention_config, cache_config=cache_config
) )
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with ( with (
set_current_vllm_config(vllm_config), set_current_vllm_config(vllm_config),
patch("vllm.platforms.current_platform", CudaPlatform()), patch("vllm.platforms.current_platform", CudaPlatform()),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment