Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e48b2e68
Unverified
Commit
e48b2e68
authored
Nov 24, 2025
by
vllmellm
Committed by
GitHub
Nov 24, 2025
Browse files
[Bugfix] [ROCm] [UX] Reorganize ROCm Backend Selection Logic (#26980)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
7a228b53
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
394 additions
and
23 deletions
+394
-23
tests/v1/attention/test_rocm_attention_backends_selection.py
tests/v1/attention/test_rocm_attention_backends_selection.py
+337
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+57
-23
No files found.
tests/v1/attention/test_rocm_attention_backends_selection.py
0 → 100644
View file @
e48b2e68
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for attention backend selectors."""
from
unittest.mock
import
MagicMock
,
patch
import
pytest
import
torch
from
vllm.attention.backends.registry
import
AttentionBackendEnum
from
vllm.platforms
import
current_platform
# Every test in this module calls into RocmPlatform, so the whole module is
# skipped when the current platform is not ROCm.
pytestmark = pytest.mark.skipif(
    not current_platform.is_rocm(),
    reason="ROCm-specific tests",
)
@pytest.fixture
def mock_vllm_config():
    """Return a ``MagicMock`` that stands in for a VllmConfig.

    The mock exposes ``model_config.dtype`` (float16),
    ``model_config.hf_config.architectures`` (``["LlamaForCausalLM"]``) and
    ``cache_config.block_size`` (16).
    """
    vllm_config = MagicMock()
    # configure_mock resolves dotted names, so the nested attributes can be
    # populated in a single call.
    vllm_config.configure_mock(
        **{
            "model_config.dtype": torch.float16,
            "model_config.hf_config.architectures": ["LlamaForCausalLM"],
            "cache_config.block_size": 16,
        }
    )
    return vllm_config
@pytest.fixture
def mock_on_gfx9():
    """Patch ``vllm.platforms.rocm.on_gfx9`` to return True while a test runs."""
    gfx9_patcher = patch("vllm.platforms.rocm.on_gfx9", return_value=True)
    gfx9_patcher.start()
    try:
        yield
    finally:
        # Undo the patch on fixture teardown.
        gfx9_patcher.stop()
@pytest.mark.parametrize(
    "env_vars, selected_backend, expected_backend_path",
    [
        # Test Case 1: Default (no env vars, no explicit backend)
        (
            {},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 2: Explicit TRITON_ATTN backend
        (
            {},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 3: Explicit ROCM_ATTN backend
        (
            {},
            "ROCM_ATTN",
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
        # Test Case 4: Explicit ROCM_AITER_FA backend
        (
            {},
            "ROCM_AITER_FA",
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
        (
            {},
            "ROCM_AITER_UNIFIED_ATTN",
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Test Case 6: VLLM_ROCM_USE_AITER=1
        # (defaults to AITER FA when MHA not explicitly disabled)
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
            None,
            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
        ),
        # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
        (
            {
                "VLLM_ROCM_USE_AITER": "1",
                "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
            },
            None,
            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
        ),
        # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
        (
            {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
            None,
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
        # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_ATTN",
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
        # (explicitly disabled)
        (
            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
            None,
            AttentionBackendEnum.TRITON_ATTN.get_path(),
        ),
        # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "ROCM_ATTN",
            AttentionBackendEnum.ROCM_ATTN.get_path(),
        ),
    ],
)
def test_standard_attention_backend_selection(
    env_vars,
    selected_backend,
    expected_backend_path,
    mock_vllm_config,
    mock_on_gfx9,
    monkeypatch,
):
    """Check which non-MLA backend path ``RocmPlatform`` selects.

    Args:
        env_vars: Environment variables to set for this case.
        selected_backend: Name of an ``AttentionBackendEnum`` member to request
            explicitly, or ``None`` for automatic selection.
        expected_backend_path: Backend class path the platform must return.
        mock_vllm_config: Fixture providing a mocked config (requested only
            for its setup; not read directly here).
        mock_on_gfx9: Fixture forcing ``on_gfx9()`` to return True.
        monkeypatch: pytest fixture used to manipulate the environment.
    """
    # Clear every variable the parametrized cases toggle, so values inherited
    # from the outer environment cannot leak into cases that expect them to
    # be unset (e.g. the default case).
    for inherited_name in (
        "VLLM_ROCM_USE_AITER",
        "VLLM_ROCM_USE_AITER_MHA",
        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
    ):
        monkeypatch.delenv(inherited_name, raising=False)

    # Set environment variables
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    # Import after setting env vars to ensure they're picked up
    # Reload envs to pick up new environment variables
    import importlib

    import vllm.envs as envs

    importlib.reload(envs)

    # Convert string backend to enum if provided. Use the same enum that the
    # expected paths are built from, so the platform's comparisons match.
    backend_enum = None
    if selected_backend:
        backend_enum = getattr(AttentionBackendEnum, selected_backend)

    # Get the backend class path
    from vllm.platforms.rocm import RocmPlatform

    backend_path = RocmPlatform.get_attn_backend_cls(
        selected_backend=backend_enum,
        head_size=128,
        dtype=torch.float16,
        kv_cache_dtype="auto",
        block_size=16,
        use_mla=False,
        has_sink=False,
        use_sparse=False,
    )
    assert backend_path == expected_backend_path
@pytest.mark.parametrize(
    "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
    [
        # Test Case 1: TRITON_MLA with block_size != 1
        (
            {},
            "TRITON_MLA",
            16,
            AttentionBackendEnum.TRITON_MLA.get_path(),
            False,
        ),
        # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
        (
            {},
            "TRITON_MLA",
            1,
            None,
            True,
        ),
        # Test Case 3: ROCM_AITER_MLA with block_size == 1
        (
            {},
            "ROCM_AITER_MLA",
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 4: ROCM_AITER_MLA with block_size != 1
        # (expected to be accepted; this case does not raise)
        (
            {},
            "ROCM_AITER_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            1,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
        # (should use ROCM_AITER_MLA now, as it supports block_size 16)
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            None,
            16,
            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
            False,
        ),
        # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
        (
            {"VLLM_ROCM_USE_AITER": "1"},
            "TRITON_MLA",
            16,
            AttentionBackendEnum.TRITON_MLA.get_path(),
            False,
        ),
        # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
        (
            {},
            "ROCM_AITER_TRITON_MLA",
            16,
            AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
            False,
        ),
    ],
)
def test_mla_backend_selection(
    env_vars,
    selected_backend,
    block_size,
    expected_backend_path,
    should_raise,
    mock_vllm_config,
    monkeypatch,
):
    """Check which MLA backend path ``RocmPlatform`` selects.

    Args:
        env_vars: Environment variables to set for this case.
        selected_backend: Name of an ``AttentionBackendEnum`` member to request
            explicitly, or ``None`` for automatic selection.
        block_size: Block size passed to ``get_attn_backend_cls``.
        expected_backend_path: Backend class path the platform must return
            (ignored when ``should_raise`` is True).
        should_raise: Whether the call is expected to raise ``ValueError``.
        mock_vllm_config: Fixture providing a mocked config (requested only
            for its setup; not read directly here).
        monkeypatch: pytest fixture used to manipulate the environment.
    """
    # Clear the selection-related variables first so values inherited from
    # the outer environment cannot influence cases that leave them unset.
    for inherited_name in (
        "VLLM_ROCM_USE_AITER",
        "VLLM_ROCM_USE_AITER_MHA",
        "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
        "VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
    ):
        monkeypatch.delenv(inherited_name, raising=False)

    # Set environment variables
    for key, value in env_vars.items():
        monkeypatch.setenv(key, value)

    # Import after setting env vars
    # Reload envs
    import importlib

    import vllm.envs as envs

    importlib.reload(envs)

    # Mock rocm_aiter_ops.is_mla_enabled based on the env vars of this case
    aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"

    mock_rocm_ops = MagicMock()
    mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
    mock_aiter_module = MagicMock()
    mock_aiter_module.rocm_aiter_ops = mock_rocm_ops

    with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
        # Convert string backend to enum if provided. Use the same enum that
        # the expected paths are built from.
        backend_enum = None
        if selected_backend:
            backend_enum = getattr(AttentionBackendEnum, selected_backend)

        from vllm.platforms.rocm import RocmPlatform

        # Both outcomes exercise the exact same call; only the expectation
        # differs, so the arguments are defined once.
        call_kwargs = {
            "selected_backend": backend_enum,
            "head_size": 128,
            "dtype": torch.float16,
            "kv_cache_dtype": "auto",
            "block_size": block_size,
            "use_mla": True,
            "has_sink": False,
            "use_sparse": False,
        }

        if should_raise:
            with pytest.raises(ValueError):
                RocmPlatform.get_attn_backend_cls(**call_kwargs)
        else:
            backend_path = RocmPlatform.get_attn_backend_cls(**call_kwargs)
            assert backend_path == expected_backend_path
def test_aiter_fa_requires_gfx9(mock_vllm_config):
    """Test that explicitly selecting ROCM_AITER_FA raises off gfx9.

    With ``on_gfx9`` patched to return False, requesting the ROCM_AITER_FA
    backend must raise a ``ValueError`` whose message contains
    "only supported on gfx9".
    """
    from vllm.platforms.rocm import RocmPlatform

    # Mock on_gfx9 to return False
    with (
        patch("vllm.platforms.rocm.on_gfx9", return_value=False),
        pytest.raises(
            ValueError,
            match="only supported on gfx9",
        ),
    ):
        RocmPlatform.get_attn_backend_cls(
            # Use the enum already imported at module scope, i.e. the same
            # one the platform code compares against.
            selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype="auto",
            block_size=16,
            use_mla=False,
            has_sink=False,
            use_sparse=False,
        )
def test_sparse_not_supported(mock_vllm_config):
    """Check that a sparse request with block_size=16 trips the ROCm assertion.

    The call is expected to fail with an ``AssertionError`` whose message
    states that the sparse MLA backend on ROCm only supports block size 1.
    """
    from vllm.platforms.rocm import RocmPlatform

    sparse_request = {
        "selected_backend": None,
        "head_size": 128,
        "dtype": torch.float16,
        "kv_cache_dtype": "auto",
        "block_size": 16,
        "use_mla": False,
        "has_sink": False,
        "use_sparse": True,
    }
    expected_message = "Sparse MLA backend on ROCm only supports block size 1"

    with pytest.raises(AssertionError, match=expected_message):
        RocmPlatform.get_attn_backend_cls(**sparse_request)
vllm/platforms/rocm.py
View file @
e48b2e68
...
...
@@ -262,30 +262,64 @@ class RocmPlatform(Platform):
f
"is not MLA type while requested for MLA backend."
)
if
selected_backend
==
AttentionBackendEnum
.
FLEX_ATTENTION
:
logger
.
info
(
"Using FlexAttention backend."
)
return
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if
(
rocm_aiter_ops
.
is_mha_enabled
()
)
or
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
logger
.
info
(
"Using Aiter Flash Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
if
(
rocm_aiter_ops
.
is_triton_unified_attn_enabled
()
)
or
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
:
logger
.
info
(
"Using Aiter Unified Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
.
get_path
()
if
(
envs
.
VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or
selected_backend
==
AttentionBackendEnum
.
ROCM_ATTN
):
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger
.
info
(
"Using Rocm Attention backend."
)
if
selected_backend
==
AttentionBackendEnum
.
TRITON_ATTN
:
logger
.
info
(
"Using Triton Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
if
selected_backend
==
AttentionBackendEnum
.
ROCM_ATTN
:
logger
.
info
(
"Using Rocm Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_ATTN
.
get_path
()
# default case, using triton unified attention
logger
.
info
(
"Using Triton Attention backend."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
if
on_gfx9
():
logger
.
info
(
"Using Aiter Flash Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
else
:
raise
ValueError
(
f
"The selected backend,
{
selected_backend
.
name
}
, "
"is only supported on gfx9 architectures."
)
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
:
logger
.
info
(
"Using Aiter Unified Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
.
get_path
()
# Handle automatic backend selection based on environment variables
if
selected_backend
is
None
:
# Priority 1: Check for AITER Unified Attention (must check before MHA)
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
:
logger
.
info
(
"Using Aiter Unified Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
.
get_path
()
# Priority 2: Check for AITER MHA (Flash Attention)
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
and
on_gfx9
():
logger
.
info
(
"Using Aiter Flash Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
if
envs
.
VLLM_V1_USE_PREFILL_DECODE_ATTENTION
:
logger
.
info
(
"Using Rocm Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_ATTN
.
get_path
()
# Priority 4: Check for AITER enabled without specific flags
# This defaults to AITER FA only if MHA is not explicitly disabled
if
(
envs
.
VLLM_ROCM_USE_AITER
and
on_gfx9
()
and
envs
.
VLLM_ROCM_USE_AITER_MHA
is
not
False
):
logger
.
info
(
"Using Aiter Flash Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
# Default: Triton Unified Attention
logger
.
info
(
"Using Triton Attention backend on V1 engine."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
raise
RuntimeError
(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
@
classmethod
def
set_device
(
cls
,
device
:
torch
.
device
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment