Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c5362c73
Unverified
Commit
c5362c73
authored
Mar 05, 2026
by
Rohan Potdar
Committed by
GitHub
Mar 05, 2026
Browse files
Reenable features for ROCm attention backends (#36185)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
0a49676f
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
66 additions
and
33 deletions
+66
-33
docs/design/attention_backends.md
docs/design/attention_backends.md
+5
-5
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+1
-1
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+10
-0
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+11
-10
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+0
-5
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+8
-0
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+17
-1
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+14
-11
No files found.
docs/design/attention_backends.md
View file @
c5362c73
...
@@ -171,9 +171,9 @@ Priority is **1 = highest** (tried first).
...
@@ -171,9 +171,9 @@ Priority is **1 = highest** (tried first).
|
`FLASH_ATTN`
| FA4
*
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
`FLASH_ATTN`
| FA4
*
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_UNIFIED_ATTN`
| | fp16, bf16 |
`auto`
|
Any
| Any |
❌
|
❌
| ❌ | All | N/A |
|
`ROCM_AITER_UNIFIED_ATTN`
| | fp16, bf16 |
`auto`
|
%16
| Any |
✅
|
✅
| ❌ | All | N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
| 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 |
❌
|
❌
| ❌ | All | N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 |
✅
|
✅
| ❌ | All | N/A |
|
`TREE_ATTN`
| | fp16, bf16 |
`auto`
| %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
`TREE_ATTN`
| | fp16, bf16 |
`auto`
| %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
`TRITON_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
`TRITON_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ✅ | ❌ | All | Any |
...
@@ -210,7 +210,7 @@ configuration.
...
@@ -210,7 +210,7 @@ configuration.
|
`FLASHMLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
`FLASHMLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
| 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASHMLA_SPARSE`
| bf16 |
`auto`
,
`bfloat16`
,
`fp8_ds_mla`
| 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
`FLASH_ATTN_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
| bf16 |
`auto`
| Any
| 576
| ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_MLA_SPARSE`
|
fp16,
bf16 |
`auto`
,
`bfloat16`
| 1
| Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_TRITON_MLA`
| fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
|
`TRITON_MLA`
| fp16, bf16 |
`auto`
,
`bfloat16`
| Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
vllm/v1/attention/backend.py
View file @
c5362c73
...
@@ -252,7 +252,7 @@ class AttentionBackend(ABC):
...
@@ -252,7 +252,7 @@ class AttentionBackend(ABC):
else
:
else
:
invalid_reasons
.
append
(
"non-MLA not supported"
)
invalid_reasons
.
append
(
"non-MLA not supported"
)
if
has_sink
and
not
cls
.
supports_sink
():
if
has_sink
and
not
cls
.
supports_sink
():
invalid_reasons
.
append
(
"
sink setting
not supported"
)
invalid_reasons
.
append
(
"
attention sinks
not supported"
)
if
use_sparse
!=
cls
.
is_sparse
():
if
use_sparse
!=
cls
.
is_sparse
():
if
use_sparse
:
if
use_sparse
:
invalid_reasons
.
append
(
"sparse not supported"
)
invalid_reasons
.
append
(
"sparse not supported"
)
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
c5362c73
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.model_executor.layers.attention.mla_attention
import
(
from
vllm.model_executor.layers.attention.mla_attention
import
(
MLACommonBackend
,
MLACommonBackend
,
MLACommonDecodeMetadata
,
MLACommonDecodeMetadata
,
...
@@ -21,6 +22,15 @@ from vllm.v1.kv_cache_interface import AttentionSpec
...
@@ -21,6 +22,15 @@ from vllm.v1.kv_cache_interface import AttentionSpec
class
AiterMLABackend
(
MLACommonBackend
):
class
AiterMLABackend
(
MLACommonBackend
):
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
]
@
staticmethod
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
1
]
return
[
1
]
...
...
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
View file @
c5362c73
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
(
from
vllm.model_executor.layers.attention.mla_attention
import
(
get_mla_dims
,
get_mla_dims
,
...
@@ -21,6 +22,7 @@ from vllm.v1.attention.backend import (
...
@@ -21,6 +22,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
MultipleOf
,
SparseMLAAttentionImpl
,
SparseMLAAttentionImpl
,
)
)
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
from
vllm.v1.attention.backends.mla.flashmla_sparse
import
(
...
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(
...
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
class
ROCMAiterMLASparseBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
]
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
1
]
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
@@ -105,10 +115,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
...
@@ -105,10 +115,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
return
(
num_blocks
,
block_size
,
head_size
)
@
classmethod
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
576
]
@
classmethod
@
classmethod
def
is_mla
(
cls
)
->
bool
:
def
is_mla
(
cls
)
->
bool
:
return
True
return
True
...
@@ -117,11 +123,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
...
@@ -117,11 +123,6 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
def
is_sparse
(
cls
)
->
bool
:
def
is_sparse
(
cls
)
->
bool
:
return
True
return
True
@
classmethod
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
# The only supported block_size is 1
return
block_size
is
None
or
block_size
==
1
@
dataclass
@
dataclass
class
ROCMAiterMLASparseMetadata
(
AttentionMetadata
):
class
ROCMAiterMLASparseMetadata
(
AttentionMetadata
):
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
c5362c73
...
@@ -45,11 +45,6 @@ class TritonMLABackend(MLACommonBackend):
...
@@ -45,11 +45,6 @@ class TritonMLABackend(MLACommonBackend):
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
def
supports_compute_capability
(
cls
,
capability
:
DeviceCapability
)
->
bool
:
return
True
return
True
@
classmethod
def
supports_block_size
(
cls
,
block_size
:
int
|
None
)
->
bool
:
# The only unsupported block_size is 1
return
block_size
is
None
or
block_size
!=
1
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
can_return_lse_for_decode
:
bool
=
True
can_return_lse_for_decode
:
bool
=
True
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
c5362c73
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -732,6 +733,13 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -732,6 +733,13 @@ class AiterFlashAttentionMetadataBuilder(
class
AiterFlashAttentionBackend
(
AttentionBackend
):
class
AiterFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_dtypes
:
ClassVar
[
list
[
torch
.
dtype
]]
=
[
torch
.
float16
,
torch
.
bfloat16
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
]
@
staticmethod
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
...
...
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
View file @
c5362c73
...
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey
,
QuantKey
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
)
)
from
vllm.v1.attention.backend
import
AttentionLayer
,
AttentionType
from
vllm.v1.attention.backend
import
AttentionLayer
,
AttentionType
,
MultipleOf
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.rocm_attn
import
(
from
vllm.v1.attention.backends.rocm_attn
import
(
RocmAttentionBackend
,
RocmAttentionBackend
,
...
@@ -25,6 +25,22 @@ logger = init_logger(__name__)
...
@@ -25,6 +25,22 @@ logger = init_logger(__name__)
class
RocmAiterUnifiedAttentionBackend
(
RocmAttentionBackend
):
class
RocmAiterUnifiedAttentionBackend
(
RocmAttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
return
[
MultipleOf
(
16
)]
@
classmethod
def
supports_head_size
(
cls
,
head_size
:
int
)
->
bool
:
return
head_size
>=
32
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_sink
(
cls
)
->
bool
:
return
True
forward_includes_kv_cache_update
:
bool
=
False
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
@
staticmethod
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
c5362c73
...
@@ -9,6 +9,7 @@ import torch
...
@@ -9,6 +9,7 @@ import torch
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.cache
import
CacheDType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
QuantKey
,
...
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
...
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
torch
.
bfloat16
,
torch
.
bfloat16
,
torch
.
float32
,
torch
.
float32
,
]
]
supported_kv_cache_dtypes
:
ClassVar
[
list
[
CacheDType
]]
=
[
"auto"
,
"bfloat16"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
]
@
staticmethod
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
...
@@ -185,15 +193,12 @@ class RocmAttentionBackend(AttentionBackend):
...
@@ -185,15 +193,12 @@ class RocmAttentionBackend(AttentionBackend):
return
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
224
,
256
]
return
[
32
,
64
,
80
,
96
,
128
,
160
,
192
,
224
,
256
]
@
classmethod
@
classmethod
def
validate_head_size
(
cls
,
head_size
:
int
)
->
None
:
def
supports_mm_prefix
(
cls
)
->
bool
:
if
not
cls
.
supports_head_size
(
head_size
):
return
True
attn_type
=
cls
.
__name__
.
removesuffix
(
"Backend"
)
raise
ValueError
(
@
classmethod
f
"Head size
{
head_size
}
is not supported by
{
attn_type
}
. "
def
supports_sink
(
cls
)
->
bool
:
f
"Supported head sizes are:
{
cls
.
get_supported_head_sizes
()
}
. "
return
True
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
forward_includes_kv_cache_update
:
bool
=
False
forward_includes_kv_cache_update
:
bool
=
False
...
@@ -275,8 +280,6 @@ class RocmAttentionImpl(AttentionImpl):
...
@@ -275,8 +280,6 @@ class RocmAttentionImpl(AttentionImpl):
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
RocmAttentionBackend
.
validate_head_size
(
head_size
)
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
self
.
fp8_dtype
=
current_platform
.
fp8_dtype
()
self
.
sinks
=
sinks
self
.
sinks
=
sinks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment