Unverified commit ec10fd0a, authored by Wenzheng Bi, committed by GitHub
Browse files

[Bugfix] Move current_platform import to avoid python import cache. (#16601)


Signed-off-by: iwzbi <wzbi@zju.edu.cn>
parent 0426e3c5
...@@ -84,12 +84,12 @@ def test_env( ...@@ -84,12 +84,12 @@ def test_env(
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None, block_size)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "TORCH_SDPA"
elif device == "hip": elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()): with patch("vllm.platforms.current_platform", RocmPlatform()):
if use_mla: if use_mla:
# ROCm MLA backend logic: # ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1 # - TRITON_MLA: supported when block_size != 1
...@@ -126,7 +126,7 @@ def test_env( ...@@ -126,7 +126,7 @@ def test_env(
assert backend.get_name() == expected assert backend.get_name() == expected
elif device == "cuda": elif device == "cuda":
with patch("vllm.attention.selector.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
if use_mla: if use_mla:
# CUDA MLA backend logic: # CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128 # - CUTLASS_MLA: only supported with block_size == 128
...@@ -214,12 +214,12 @@ def test_env( ...@@ -214,12 +214,12 @@ def test_env(
def test_fp32_fallback(device: str): def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32.""" """Test attention backend selection with fp32."""
if device == "cpu": if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "TORCH_SDPA" assert backend.get_name() == "TORCH_SDPA"
elif device == "cuda": elif device == "cuda":
with patch("vllm.attention.selector.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None, 16)
assert backend.get_name() == "FLEX_ATTENTION" assert backend.get_name() == "FLEX_ATTENTION"
...@@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch): ...@@ -277,7 +277,7 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
"""Test that invalid attention backend names raise ValueError.""" """Test that invalid attention backend names raise ValueError."""
with ( with (
monkeypatch.context() as m, monkeypatch.context() as m,
patch("vllm.attention.selector.current_platform", CudaPlatform()), patch("vllm.platforms.current_platform", CudaPlatform()),
): ):
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL) m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
......
...@@ -14,7 +14,6 @@ import vllm.envs as envs ...@@ -14,7 +14,6 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -192,6 +191,8 @@ def _cached_get_attn_backend( ...@@ -192,6 +191,8 @@ def _cached_get_attn_backend(
) )
# get device-specific attn_backend # get device-specific attn_backend
from vllm.platforms import current_platform
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
selected_backend, selected_backend,
head_size, head_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment