Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
43cc5138
Unverified
Commit
43cc5138
authored
Mar 29, 2026
by
Andreas Karatzas
Committed by
GitHub
Mar 28, 2026
Browse files
[ROCm][CI] Fix cross-attention dispatch for encoder-decoder models (#38450)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
5b8c30d6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
90 additions
and
19 deletions
+90
-19
docs/design/attention_backends.md
docs/design/attention_backends.md
+2
-2
tests/entrypoints/openai/speech_to_text/test_transcription_validation_whisper.py
...i/speech_to_text/test_transcription_validation_whisper.py
+52
-3
tools/pre_commit/generate_attention_backend_docs.py
tools/pre_commit/generate_attention_backend_docs.py
+1
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+22
-6
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+6
-5
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+7
-2
No files found.
docs/design/attention_backends.md
View file @
43cc5138
...
...
@@ -173,9 +173,9 @@ Priority is **1 = highest** (tried first).
|
`FLASH_ATTN`
| FA4
*
| fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
|
`FLASH_ATTN_DIFFKV`
| | fp16, bf16 |
`auto`
| Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
|
`FLEX_ATTENTION`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
| Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder
, Enc-Dec
| N/A |
|
`ROCM_AITER_FA`
| | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
|
`ROCM_AITER_UNIFIED_ATTN`
| | fp16, bf16 |
`auto`
| %16 | Any | ✅ | ✅ | ❌ | All | N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ |
All
| N/A |
|
`ROCM_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ |
Decoder, Encoder, Encoder Only
| N/A |
|
`TREE_ATTN`
| | fp16, bf16 |
`auto`
,
`float16`
,
`bfloat16`
| %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
`TRITON_ATTN`
| | fp16, bf16, fp32 |
`auto`
,
`float16`
,
`bfloat16`
,
`fp8`
,
`fp8_e4m3`
,
`fp8_e5m2`
| %16 | Any | ✅ | ✅ | ❌ | All | Any |
...
...
tests/entrypoints/openai/speech_to_text/test_transcription_validation_whisper.py
View file @
43cc5138
...
...
@@ -14,13 +14,62 @@ import pytest_asyncio
import
soundfile
as
sf
from
tests.utils
import
RemoteOpenAIServer
from
vllm.platforms
import
current_platform
MODEL_NAME = "openai/whisper-large-v3-turbo"

# Disable prefix caching on ROCm to reduce non-determinism in
# streaming-vs-non-streaming comparisons. Other platforms get no extra
# server arguments.
_ROCM_ARGS: list[str] = []
if current_platform.is_rocm():
    _ROCM_ARGS.append("--no-enable-prefix-caching")
@
pytest
.
fixture
(
scope
=
"module"
)
def
server
():
with
RemoteOpenAIServer
(
MODEL_NAME
,
[])
as
remote_server
:
def
_get_attention_backend_params
()
->
list
[
str
|
None
]:
"""Return attention backends to parametrize the server fixture with.
On ROCm, we test multiple backends explicitly:
- None: default auto-selection (ROCM_ATTN for decoder self-attention,
falls back to ROCM_AITER_UNIFIED_ATTN or TRITON_ATTN for
cross-attention since ROCM_ATTN doesn't support ENCODER_DECODER)
- TRITON_ATTN: always available on ROCm
- ROCM_AITER_UNIFIED_ATTN: only on gfx942/gfx950
On non-ROCm platforms, we just run with the default backend.
"""
try
:
from
vllm.platforms
import
current_platform
if
current_platform
.
is_rocm
():
backends
:
list
[
str
|
None
]
=
[
None
,
"TRITON_ATTN"
]
from
vllm.platforms.rocm
import
_ON_MI3XX
if
_ON_MI3XX
:
backends
.
append
(
"ROCM_AITER_UNIFIED_ATTN"
)
return
backends
except
Exception
:
pass
return
[
None
]
# Aiter backends need VLLM_ROCM_USE_AITER=1 (and MHA=1 for ROCM_AITER_FA)
# to be enabled in the server subprocess.
_AITER_ENV = dict(
    VLLM_ROCM_USE_AITER="1",
    VLLM_ROCM_USE_AITER_MHA="1",
)

# Backends the server fixture is parametrized over, and the matching pytest
# ids; the "no explicit backend" entry (None) is shown as "default".
_ATTN_BACKENDS = _get_attention_backend_params()
_ATTN_IDS = [name if name else "default" for name in _ATTN_BACKENDS]
@pytest.fixture(scope="module", params=_ATTN_BACKENDS, ids=_ATTN_IDS)
def server(request):
    """Start one RemoteOpenAIServer per parametrized attention backend.

    ``request.param`` is either ``None`` (no ``--attention-backend`` flag is
    passed) or a backend name. Backend names containing "AITER" additionally
    get the ``_AITER_ENV`` variables passed as ``env_dict``.

    Yields:
        The running ``RemoteOpenAIServer`` context object.
    """
    backend = request.param
    cli_args = list(_ROCM_ARGS)
    server_env = None

    if backend is not None:
        cli_args.extend(("--attention-backend", backend))
        server_env = _AITER_ENV if "AITER" in backend else None

    with RemoteOpenAIServer(
        MODEL_NAME, cli_args, env_dict=server_env
    ) as remote_server:
        yield remote_server
...
...
tools/pre_commit/generate_attention_backend_docs.py
View file @
43cc5138
...
...
@@ -446,7 +446,7 @@ def parse_attention_types(node: ast.ClassDef) -> str:
if
not
types
:
return
"Decoder"
return
"All"
if
len
(
types
)
>=
3
else
", "
.
join
(
sorted
(
types
))
return
"All"
if
types
>=
set
(
type_map
.
values
())
else
", "
.
join
(
sorted
(
types
))
def
parse_impl_bool_attr
(
...
...
vllm/platforms/rocm.py
View file @
43cc5138
...
...
@@ -439,7 +439,10 @@ class RocmPlatform(Platform):
f
"this configuration. Reason:
{
invalid_reasons
}
"
)
else
:
logger
.
info
(
"Using %s backend."
,
selected_backend
)
logger
.
info_once
(
"Using %s backend (selected via --attention-backend)."
,
selected_backend
.
name
,
)
return
selected_backend
.
get_path
()
# No selected backend or the selected backend is invalid,
...
...
@@ -476,11 +479,24 @@ class RocmPlatform(Platform):
)
selected_index
=
sorted_indices
[
0
]
selected_backend
=
valid_backends_priorities
[
selected_index
][
0
]
valid_str
=
(
"["
+
", "
.
join
(
f
"'
{
b
[
0
].
name
}
'"
for
b
in
valid_backends_priorities
)
+
"]"
)
if
invalid_reasons
:
rejected_str
=
", "
.
join
(
b
.
name
for
b
in
invalid_reasons
)
logger
.
info
(
"Found incompatible backend(s) [%s] with %s. "
"Overriding with %s out of potential backends: %s."
,
rejected_str
,
attn_selector_config
.
attn_type
,
selected_backend
.
name
,
valid_str
,
)
else
:
logger
.
info_once
(
"Using %s
attention
backend out of potential backends: %s."
,
"Using %s backend out of potential backends: %s."
,
selected_backend
.
name
,
"["
+
", "
.
join
(
f
"'
{
b
[
0
].
name
}
'"
for
b
in
valid_backends_priorities
)
+
"]"
,
scope
=
"local"
,
valid_str
,
)
return
selected_backend
.
get_path
()
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
43cc5138
...
...
@@ -758,11 +758,12 @@ class AiterFlashAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether ROCM AITER FA can serve the given attention type.

    ENCODER_DECODER is not supported because the prefill path uses
    flash_attn_varlen_func with cu_seqlens_k set to decoder
    query_start_loc (not encoder seq lens) and causal=True, both of
    which are incorrect for cross-attention layers.

    Args:
        attn_type: The attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        True only for ``AttentionType.DECODER``.
    """
    # The earlier implementation also accepted ENCODER_DECODER; that return
    # must not remain ahead of this one, or cross-attention layers would
    # still be dispatched to this backend.
    return attn_type in (AttentionType.DECODER,)
@
staticmethod
def
get_supported_kernel_block_sizes
()
->
list
[
int
|
MultipleOf
]:
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
43cc5138
...
...
@@ -212,12 +212,17 @@ class RocmAttentionBackend(AttentionBackend):
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
    """Report whether RocmAttention can serve the given attention type.

    ENCODER_DECODER is not supported because
    chunked_prefill_paged_decode's prefill kernel (context_attention_fwd)
    assumes self-attention semantics: it treats passed K/V as new tokens
    to mix with cached K/V. For cross-attention layers the encoder K/V
    are already fully cached, so mixing them again produces incorrect
    results when max_query_len > 1 (e.g. beam search).

    Args:
        attn_type: The attention type to check, compared against the
            ``AttentionType`` constants.

    Returns:
        True for DECODER, ENCODER and ENCODER_ONLY; False otherwise.
    """
    # ENCODER_DECODER is deliberately absent from this tuple; see the
    # docstring for why cross-attention is rejected.
    return attn_type in (
        AttentionType.DECODER,
        AttentionType.ENCODER,
        AttentionType.ENCODER_ONLY,
    )
@
staticmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment