Unverified Commit 18e01a0a authored by Nicolò Lucchesi's avatar Nicolò Lucchesi Committed by GitHub
Browse files

[Misc] Add `--attention-backend auto` option (#35738)


Signed-off-by: default avatarNickLucche <nlucches@redhat.com>
parent 6cb90109
...@@ -293,6 +293,48 @@ def test_invalid_backend(): ...@@ -293,6 +293,48 @@ def test_invalid_backend():
AttentionConfig(backend=AttentionBackendEnum["INVALID"]) AttentionConfig(backend=AttentionBackendEnum["INVALID"])
@pytest.mark.parametrize("auto_value", ["auto", "AUTO", "Auto"])
def test_auto_backend_string(auto_value: str):
"""Test that 'auto' string value triggers automatic backend selection."""
# Using "auto" should result in backend=None (automatic selection)
attention_config = AttentionConfig(backend=auto_value)
assert attention_config.backend is None
def test_auto_backend_selection_behavior():
"""Test that 'auto' backend behaves same as None (automatic selection)."""
# Create config with explicit "auto"
auto_config = AttentionConfig(backend="auto")
# Create config with None (default)
none_config = AttentionConfig(backend=None)
# Both should have backend=None
assert auto_config.backend is None
assert none_config.backend is None
# Both configs should result in the same automatic backend selection
vllm_config_auto = VllmConfig(attention_config=auto_config)
vllm_config_none = VllmConfig(attention_config=none_config)
with (
set_current_vllm_config(vllm_config_auto),
patch("vllm.platforms.current_platform", CpuPlatform()),
):
backend_auto = get_attn_backend(16, torch.float16, None, 16)
_cached_get_attn_backend.cache_clear()
with (
set_current_vllm_config(vllm_config_none),
patch("vllm.platforms.current_platform", CpuPlatform()),
):
backend_none = get_attn_backend(16, torch.float16, None, 16)
# Both should select the same backend
assert backend_auto.get_name() == backend_none.get_name()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"backend_name,flash_attn_version,should_succeed", "backend_name,flash_attn_version,should_succeed",
[ [
......
...@@ -14,7 +14,7 @@ class AttentionConfig: ...@@ -14,7 +14,7 @@ class AttentionConfig:
"""Configuration for attention mechanisms in vLLM.""" """Configuration for attention mechanisms in vLLM."""
backend: AttentionBackendEnum | None = None backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically.""" """Attention backend to use. Use "auto" or None for automatic selection."""
flash_attn_version: Literal[2, 3, 4] | None = None flash_attn_version: Literal[2, 3, 4] | None = None
"""Force vllm to use a specific flash-attention version (2, 3, or 4). """Force vllm to use a specific flash-attention version (2, 3, or 4).
...@@ -63,7 +63,13 @@ class AttentionConfig: ...@@ -63,7 +63,13 @@ class AttentionConfig:
@field_validator("backend", mode="before") @field_validator("backend", mode="before")
@classmethod @classmethod
def validate_backend_before(cls, value: Any) -> Any: def validate_backend_before(cls, value: Any) -> Any:
"""Enable parsing of the `backend` enum type from string.""" """Enable parsing of the `backend` enum type from string.
The special value "auto" is treated as None, which triggers
automatic backend selection.
"""
if isinstance(value, str): if isinstance(value, str):
if value.lower() == "auto":
return None
return AttentionBackendEnum[value.upper()] return AttentionBackendEnum[value.upper()]
return value return value
...@@ -1816,13 +1816,10 @@ class EngineArgs: ...@@ -1816,13 +1816,10 @@ class EngineArgs:
"attention_backend and attention_config.backend " "attention_backend and attention_config.backend "
"are mutually exclusive" "are mutually exclusive"
) )
# Convert string to enum if needed (CLI parsing returns a string) # Reuse the validator to handle "auto" and string-to-enum conversion
if isinstance(self.attention_backend, str): attention_config.backend = AttentionConfig.validate_backend_before(
attention_config.backend = AttentionBackendEnum[ self.attention_backend
self.attention_backend.upper() )
]
else:
attention_config.backend = self.attention_backend
# Kernel config overrides # Kernel config overrides
kernel_config = copy.deepcopy(self.kernel_config) kernel_config = copy.deepcopy(self.kernel_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment