Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f95c11a8
Unverified
Commit
f95c11a8
authored
Apr 21, 2026
by
hangy-amd
Committed by
GitHub
Apr 21, 2026
Browse files
[Feat] dflash support for ROCm (#39703)
Signed-off-by:
Hang Yang
<
hangy@amd.com
>
parent
257015d5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
75 additions
and
29 deletions
+75
-29
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+75
-29
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
f95c11a8
...
@@ -389,6 +389,7 @@ class AiterFlashAttentionMetadata:
...
@@ -389,6 +389,7 @@ class AiterFlashAttentionMetadata:
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
causal
:
bool
# prefill and decode split
# prefill and decode split
num_decodes
:
int
num_decodes
:
int
...
@@ -676,6 +677,7 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -676,6 +677,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
causal
=
common_attn_metadata
.
causal
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
num_decodes
=
num_decodes
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
...
@@ -724,6 +726,7 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -724,6 +726,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
causal
=
common_attn_metadata
.
causal
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
num_decodes
=
num_reqs
,
num_decodes
=
num_reqs
,
num_decode_tokens
=
num_tokens
,
num_decode_tokens
=
num_tokens
,
...
@@ -808,6 +811,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
...
@@ -808,6 +811,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
# more reliable.
# more reliable.
return
on_mi3xx
()
return
on_mi3xx
()
@
classmethod
def
supports_non_causal
(
cls
)
->
bool
:
return
True
class
AiterFlashAttentionImpl
(
AttentionImpl
):
class
AiterFlashAttentionImpl
(
AttentionImpl
):
def
__init__
(
def
__init__
(
...
@@ -1122,7 +1129,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1122,7 +1129,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
min_seqlen_q
=
1
,
min_seqlen_q
=
1
,
dropout_p
=
0.0
,
dropout_p
=
0.0
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
attn_metadata
.
causal
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output_actual_tokens
[
num_decode_tokens
+
num_extend_tokens
:],
out
=
output_actual_tokens
[
num_decode_tokens
+
num_extend_tokens
:],
...
@@ -1170,39 +1177,78 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1170,39 +1177,78 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
decode_max_query_len
=
attn_metadata
.
decode_metadata
.
max_query_len
decode_max_query_len
=
attn_metadata
.
decode_metadata
.
max_query_len
#
Use unified_attention for
speculative decod
ing (multi-token)
#
Multi-token
speculative decod
e path.
if
decode_max_query_len
>
1
:
if
decode_max_query_len
>
1
:
assert
not
rocm_aiter_ops
.
is_shuffle_kv_cache_enabled
(),
(
assert
not
rocm_aiter_ops
.
is_shuffle_kv_cache_enabled
(),
(
"Shuffle KV cache layout is not supported with "
"Shuffle KV cache layout is not supported with "
"speculative decoding (multi-token decode)."
"speculative decoding (multi-token decode)."
)
)
from
aiter.ops.triton.unified_attention
import
(
if
not
attn_metadata
.
causal
:
unified_attention
,
from
aiter.ops.triton.attention.mha_v3
import
(
)
flash_attn_with_kvcache
,
)
descale_shape
=
(
num_decodes
,
descale_shape
=
(
num_decodes
,
key_cache
.
shape
[
2
])
key_cache
.
shape
[
2
],
decode_query
=
query
[:
num_decode_tokens
].
reshape
(
)
num_decodes
,
unified_attention
(
decode_max_query_len
,
q
=
query
[:
num_decode_tokens
],
query
.
shape
[
1
],
k
=
key_cache
,
query
.
shape
[
2
],
v
=
value_cache
,
)
out
=
output
[:
num_decode_tokens
],
decode_out
=
flash_attn_with_kvcache
(
cu_seqlens_q
=
attn_metadata
.
query_start_loc
[:
num_decodes
+
1
],
q
=
decode_query
,
max_seqlen_q
=
decode_max_query_len
,
k_cache
=
key_cache
,
seqused_k
=
attn_metadata
.
seq_lens
[:
num_decodes
],
v_cache
=
value_cache
,
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
cache_seqlens
=
attn_metadata
.
seq_lens
[:
num_decodes
],
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
attn_metadata
.
causal
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
softcap
=
self
.
logits_soft_cap
,
block_table
=
attn_metadata
.
block_table
[:
num_decodes
],
q_descale
=
None
,
softcap
=
self
.
logits_soft_cap
,
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
q_descale
=
None
,
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
page_table
=
attn_metadata
.
block_table
[:
num_decodes
],
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
)
output
[:
num_decode_tokens
].
copy_
(
decode_out
.
reshape
(
num_decode_tokens
,
query
.
shape
[
1
],
query
.
shape
[
2
],
)
)
else
:
# Non-uniform query lengths can appear in real serving
# traffic (e.g. mixed datasets). Fall back to varlen
# unified_attention instead of asserting.
from
aiter.ops.triton.unified_attention
import
(
unified_attention
,
)
descale_shape
=
(
num_decodes
,
key_cache
.
shape
[
2
],
)
unified_attention
(
q
=
query
[:
num_decode_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_decode_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
[
:
num_decodes
+
1
],
max_seqlen_q
=
decode_max_query_len
,
seqused_k
=
attn_metadata
.
seq_lens
[:
num_decodes
],
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
[:
num_decodes
],
softcap
=
self
.
logits_soft_cap
,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
return
return
# The ll4mi kernel in paged_attention_v1 requires
# The ll4mi kernel in paged_attention_v1 requires
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment