Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c760b6a
Unverified
Commit
8c760b6a
authored
Mar 05, 2026
by
Sage Moore
Committed by
GitHub
Mar 05, 2026
Browse files
[ROCm] Refactor ROCm attention backend selection logic (#35246)
Signed-off-by:
Sage Moore
<
sage@neuralmagic.com
>
parent
3ee68590
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
170 additions
and
114 deletions
+170
-114
docs/design/attention_backends.md
docs/design/attention_backends.md
+1
-1
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+4
-5
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+136
-104
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+14
-4
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+5
-0
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+10
-0
No files found.
docs/design/attention_backends.md
View file @
8c760b6a
...
@@ -211,6 +211,6 @@ configuration.
...
@@ -211,6 +211,6 @@ configuration.
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
|
fp16,
bf16 |
`auto`
| Any | 576 | ❌ |
❌
| ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| bf16 |
`auto`
| Any | 576 | ❌ |
✅
| ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
tests/kernels/attention/test_attention_selector.py
View file @
8c760b6a
...
@@ -103,21 +103,20 @@ def test_backend_selection(
...
@@ -103,21 +103,20 @@ def test_backend_selection(
if
name
==
"TRITON_MLA"
and
block_size
==
1
:
if
name
==
"TRITON_MLA"
and
block_size
==
1
:
# TRITON_MLA doesn't support block_size == 1
# TRITON_MLA doesn't support block_size == 1
with
pytest
.
raises
(
ValueError
)
as
exc_info
:
with
pytest
.
raises
(
ValueError
):
get_attn_backend
(
get_attn_backend
(
1
6
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
57
6
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
)
assert
f
"The selected backend,
{
name
}
"
in
str
(
exc_info
.
value
)
else
:
else
:
# Valid backend-block_size combination
# Valid backend-block_size combination
backend
=
get_attn_backend
(
backend
=
get_attn_backend
(
1
6
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
57
6
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
)
expected
=
name
expected
=
name
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
else
:
else
:
backend
=
get_attn_backend
(
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
32
,
torch
.
float16
,
None
,
block_size
,
use_mla
=
use_mla
)
)
expected
=
"ROCM_ATTN"
expected
=
"ROCM_ATTN"
assert
backend
.
get_name
()
==
expected
assert
backend
.
get_name
()
==
expected
...
...
vllm/platforms/rocm.py
View file @
8c760b6a
...
@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
...
@@ -306,6 +306,52 @@ def flash_attn_triton_available() -> bool:
return
False
return
False
def
_get_backend_priorities
(
use_mla
:
bool
,
use_sparse
:
bool
,
)
->
list
[
AttentionBackendEnum
]:
from
vllm._aiter_ops
import
rocm_aiter_ops
if
use_sparse
:
return
[
AttentionBackendEnum
.
ROCM_AITER_MLA_SPARSE
]
if
use_mla
:
if
rocm_aiter_ops
.
is_mla_enabled
():
return
[
AttentionBackendEnum
.
ROCM_AITER_MLA
,
AttentionBackendEnum
.
TRITON_MLA
,
AttentionBackendEnum
.
ROCM_AITER_TRITON_MLA
,
]
else
:
return
[
AttentionBackendEnum
.
TRITON_MLA
,
]
backends
=
[]
# Priority 1: Check for AITER Unified Attention (must check before MHA)
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
:
backends
.
append
(
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
)
# Priority 2: Check for AITER MHA (Flash Attention)
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
:
backends
.
append
(
AttentionBackendEnum
.
ROCM_AITER_FA
)
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from
vllm.config
import
get_current_vllm_config_or_none
vllm_config
=
get_current_vllm_config_or_none
()
if
(
vllm_config
is
not
None
and
vllm_config
.
attention_config
.
use_prefill_decode_attention
):
backends
.
append
(
AttentionBackendEnum
.
ROCM_ATTN
)
# Default: Triton Unified Attention
backends
.
append
(
AttentionBackendEnum
.
TRITON_ATTN
)
return
backends
class
RocmPlatform
(
Platform
):
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
_enum
=
PlatformEnum
.
ROCM
device_name
:
str
=
"rocm"
device_name
:
str
=
"rocm"
...
@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
...
@@ -349,6 +395,39 @@ class RocmPlatform(Platform):
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
import
vllm._rocm_C
# noqa: F401
import
vllm._rocm_C
# noqa: F401
@
classmethod
def
get_valid_backends
(
cls
,
device_capability
:
DeviceCapability
,
attn_selector_config
:
"AttentionSelectorConfig"
,
num_heads
:
int
|
None
=
None
,
)
->
tuple
[
list
[
tuple
[
"AttentionBackendEnum"
,
int
]],
dict
[
"AttentionBackendEnum"
,
list
[
str
]],
]:
valid_backends_priorities
=
[]
invalid_reasons
=
{}
backend_priorities
=
_get_backend_priorities
(
attn_selector_config
.
use_mla
,
attn_selector_config
.
use_sparse
,
)
for
priority
,
backend
in
enumerate
(
backend_priorities
):
try
:
backend_class
=
backend
.
get_class
()
invalid_reasons_i
=
backend_class
.
validate_configuration
(
device_capability
=
device_capability
,
**
attn_selector_config
.
_asdict
(),
)
except
ImportError
:
invalid_reasons_i
=
[
"ImportError"
]
if
invalid_reasons_i
:
invalid_reasons
[
backend
]
=
invalid_reasons_i
else
:
valid_backends_priorities
.
append
((
backend
,
priority
))
return
valid_backends_priorities
,
invalid_reasons
@
classmethod
@
classmethod
def
get_attn_backend_cls
(
def
get_attn_backend_cls
(
cls
,
cls
,
...
@@ -356,118 +435,71 @@ class RocmPlatform(Platform):
...
@@ -356,118 +435,71 @@ class RocmPlatform(Platform):
attn_selector_config
:
"AttentionSelectorConfig"
,
attn_selector_config
:
"AttentionSelectorConfig"
,
num_heads
:
int
|
None
=
None
,
num_heads
:
int
|
None
=
None
,
)
->
str
:
)
->
str
:
from
vllm._aiter_ops
import
rocm_aiter_ops
device_capability
=
cls
.
get_device_capability
()
assert
device_capability
is
not
None
block_size
=
attn_selector_config
.
block_size
kv_cache_dtype
=
attn_selector_config
.
kv_cache_dtype
if
attn_selector_config
.
use_sparse
:
# First try checking just the selected backend, if there is one.
if
kv_cache_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
):
if
selected_backend
is
not
None
:
try
:
backend_class
=
selected_backend
.
get_class
()
invalid_reasons
=
backend_class
.
validate_configuration
(
device_capability
=
device_capability
,
**
attn_selector_config
.
_asdict
(),
)
except
ImportError
:
invalid_reasons
=
[
"ImportError"
]
if
invalid_reasons
:
raise
ValueError
(
raise
ValueError
(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
f
"Selected backend
{
selected_backend
}
is not valid for "
f
"this configuration. Reason:
{
invalid_reasons
}
"
)
)
assert
block_size
==
1
,
(
else
:
"Sparse MLA backend on ROCm only supports block size 1 for now."
logger
.
info
(
"Using %s backend."
,
selected_backend
)
return
selected_backend
.
get_path
()
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities
,
invalid_reasons
=
cls
.
get_valid_backends
(
device_capability
=
device_capability
,
attn_selector_config
=
attn_selector_config
,
num_heads
=
num_heads
,
)
)
logger
.
info_once
(
"Using Sparse MLA backend."
)
reasons_str
=
(
return
AttentionBackendEnum
.
ROCM_AITER_MLA_SPARSE
.
get_path
()
"{"
+
", "
.
join
(
if
attn_selector_config
.
use_mla
:
f
"
{
backend
.
name
}
: [
{
', '
.
join
(
reasons
)
}
]"
if
selected_backend
is
None
:
for
backend
,
reasons
in
invalid_reasons
.
items
()
selected_backend
=
(
AttentionBackendEnum
.
ROCM_AITER_MLA
if
rocm_aiter_ops
.
is_mla_enabled
()
or
block_size
==
1
else
AttentionBackendEnum
.
TRITON_MLA
)
)
if
selected_backend
==
AttentionBackendEnum
.
TRITON_MLA
:
+
"}"
if
block_size
!=
1
:
logger
.
info_once
(
"Using Triton MLA backend."
)
return
AttentionBackendEnum
.
TRITON_MLA
.
get_path
()
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
f
"does not support block size
{
block_size
}
."
)
)
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_MLA
:
config_str
=
attn_selector_config
.
__repr__
()
logger
.
info
(
"Using AITER MLA backend."
)
logger
.
debug_once
(
return
AttentionBackendEnum
.
ROCM_AITER_MLA
.
get_path
()
f
"Some attention backends are not valid for
{
cls
.
device_name
}
with "
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_TRITON_MLA
:
f
"
{
config_str
}
. Reasons:
{
reasons_str
}
."
logger
.
info
(
"Using AITER TRITON MLA backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_TRITON_MLA
.
get_path
()
raise
ValueError
(
f
" The selected backend,
{
selected_backend
.
name
}
,"
f
"is not MLA type while requested for MLA backend."
)
)
if
len
(
valid_backends_priorities
)
==
0
:
if
selected_backend
==
AttentionBackendEnum
.
FLEX_ATTENTION
:
logger
.
info
(
"Using FlexAttention backend."
)
return
AttentionBackendEnum
.
FLEX_ATTENTION
.
get_path
()
if
selected_backend
==
AttentionBackendEnum
.
TRITON_ATTN
:
logger
.
info
(
"Using Triton Attention backend."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
if
selected_backend
==
AttentionBackendEnum
.
ROCM_ATTN
:
logger
.
info
(
"Using Rocm Attention backend."
)
return
AttentionBackendEnum
.
ROCM_ATTN
.
get_path
()
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_FA
:
if
on_gfx9
():
logger
.
info
(
"Using Aiter Flash Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
else
:
raise
ValueError
(
raise
ValueError
(
f
"The selected backend,
{
selected_backend
.
name
}
,
"
f
"No valid attention backend found for
{
cls
.
device_
name
}
"
"is only supported on gfx9 architectures
."
f
"with
{
config_str
}
. Reasons:
{
reasons_str
}
."
)
)
if
selected_backend
==
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
:
# We have found some valid backends. Select the one with the
logger
.
info
(
"Using Aiter Unified Attention backend."
)
# highest priority.
return
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
.
get_path
()
sorted_indices
=
sorted
(
range
(
len
(
valid_backends_priorities
)),
# Handle automatic backend selection based on environment variables
key
=
lambda
i
:
valid_backends_priorities
[
i
][
1
],
if
selected_backend
is
None
:
)
# Priority 1: Check for AITER Unified Attention (must check before MHA)
selected_index
=
sorted_indices
[
0
]
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
:
selected_backend
=
valid_backends_priorities
[
selected_index
][
0
]
logger
.
info
(
"Using Aiter Unified Attention backend."
)
logger
.
info_once
(
return
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
.
get_path
()
"Using %s attention backend out of potential backends: %s."
,
selected_backend
.
name
,
# Priority 2: Check for AITER MHA (Flash Attention)
"["
+
", "
.
join
(
f
"'
{
b
[
0
].
name
}
'"
for
b
in
valid_backends_priorities
)
+
"]"
,
# Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
scope
=
"local"
,
if
envs
.
VLLM_ROCM_USE_AITER
and
envs
.
VLLM_ROCM_USE_AITER_MHA
and
on_gfx9
():
logger
.
info
(
"Using Aiter Flash Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from
vllm.config
import
get_current_vllm_config_or_none
vllm_config
=
get_current_vllm_config_or_none
()
if
(
vllm_config
is
not
None
and
vllm_config
.
attention_config
.
use_prefill_decode_attention
):
logger
.
info
(
"Using Rocm Attention backend."
)
return
AttentionBackendEnum
.
ROCM_ATTN
.
get_path
()
# Priority 4: Check for AITER enabled without specific flags
# This defaults to AITER FA only if MHA is not explicitly disabled
if
(
envs
.
VLLM_ROCM_USE_AITER
and
on_gfx9
()
and
envs
.
VLLM_ROCM_USE_AITER_MHA
is
not
False
):
logger
.
info
(
"Using Aiter Flash Attention backend."
)
return
AttentionBackendEnum
.
ROCM_AITER_FA
.
get_path
()
# Default: Triton Unified Attention
logger
.
info
(
"Using Triton Attention backend."
)
return
AttentionBackendEnum
.
TRITON_ATTN
.
get_path
()
raise
RuntimeError
(
f
"Attention backend
{
selected_backend
.
name
}
is not supported on "
"ROCm. Note that V0 attention backends have been removed."
)
)
return
selected_backend
.
get_path
()
@
classmethod
@
classmethod
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
def
get_supported_vit_attn_backends
(
cls
)
->
list
[
"AttentionBackendEnum"
]:
return
[
return
[
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
8c760b6a
...
@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
...
@@ -77,6 +77,7 @@ def fetch_id_to_ragged_triton(
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
bfloat16
]
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
...
@@ -104,14 +105,23 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
return
(
num_blocks
,
block_size
,
head_size
)
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
]
@
classmethod
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
return
[
576
]
@
classmethod
def
is_mla
(
cls
)
->
bool
:
return
True
@
classmethod
def
is_sparse
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
# The only supported block_size is 1
return
block_size
is
None
or
block_size
==
1
@
dataclass
@
dataclass
class
ROCMAiterMLASparseMetadata
(
AttentionMetadata
):
class
ROCMAiterMLASparseMetadata
(
AttentionMetadata
):
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
8c760b6a
...
@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
...
@@ -45,6 +45,11 @@ class TritonMLABackend(MLACommonBackend):
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
return
True
return
True
@
classmethod
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
# The only unsupported block_size is 1
return
block_size
is
None
or
block_size
!=
1
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
can_return_lse_for_decode
:
bool
=
True
can_return_lse_for_decode
:
bool
=
True
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
8c760b6a
...
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
...
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
...
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
...
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
raise
ValueError
(
"Block size must be a multiple of 16."
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
classmethod
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
from
vllm.platforms.rocm
import
on_mi3xx
# DeviceCapability is currently created using torch.cuda.get_device_capability()
# which is known to be buggy on rocm systems. on_mi3xx uses amd-smi which is
# more reliable.
return
on_mi3xx
()
class
AiterFlashAttentionImpl
(
AttentionImpl
):
class
AiterFlashAttentionImpl
(
AttentionImpl
):
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment