Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2aab2bb5
Unverified
Commit
2aab2bb5
authored
Feb 20, 2026
by
jennyyyyzhen
Committed by
GitHub
Feb 20, 2026
Browse files
[ROCM] Optimize ROCM_AITER_FA spec decode eagle performance (#34541)
Signed-off-by:
jennyyyyzhen
<
yzhen@hmc.edu
>
parent
54254f7a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
50 additions
and
2 deletions
+50
-2
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+50
-2
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
2aab2bb5
...
...
@@ -396,8 +396,7 @@ class AiterFlashAttentionMetadata:
class
AiterFlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
AiterFlashAttentionMetadata
]
):
_cudagraph_support
=
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold
:
int
=
1
_cudagraph_support
=
AttentionCGSupport
.
UNIFORM_BATCH
def
__init__
(
self
,
...
...
@@ -422,6 +421,7 @@ class AiterFlashAttentionMetadataBuilder(
# populated on first build() call.
self
.
aot_sliding_window
:
tuple
[
int
,
int
]
|
None
=
None
self
.
total_tokens
:
int
=
0
self
.
_init_reorder_batch_threshold
(
1
,
supports_spec_as_decode
=
True
)
sliding_window_configs
:
set
[
tuple
[
int
,
int
]
|
None
]
=
set
()
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
...
...
@@ -466,6 +466,7 @@ class AiterFlashAttentionMetadataBuilder(
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
)
->
"AiterFlashAttentionMetadata"
:
assert
self
.
reorder_batch_threshold
is
not
None
split_ret
=
split_decodes_prefills_and_extends
(
common_attn_metadata
,
decode_threshold
=
self
.
reorder_batch_threshold
,
...
...
@@ -677,6 +678,53 @@ class AiterFlashAttentionMetadataBuilder(
)
return
attn_metadata
def
build_for_drafting
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
draft_index
:
int
,
)
->
AiterFlashAttentionMetadata
:
"""
Build attention metadata for draft model without CPU-GPU sync.
During EAGLE drafting all requests are uniform decodes, so we can
skip split_decodes_prefills_and_extends() and avoid all .cpu() /
.item() calls that would otherwise break CUDA graph capture.
"""
num_reqs
=
common_attn_metadata
.
num_reqs
num_tokens
=
common_attn_metadata
.
num_actual_tokens
decode_metadata
=
AiterFlashAttentionDecodeMetadata
(
max_query_len
=
common_attn_metadata
.
max_query_len
,
min_query_len
=
common_attn_metadata
.
max_query_len
,
# uniform batch
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
)
return
AiterFlashAttentionMetadata
(
num_actual_tokens
=
num_tokens
,
num_actual_kv_tokens
=
0
,
# not used in unified_attention path
max_query_len
=
common_attn_metadata
.
max_query_len
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
num_decodes
=
num_reqs
,
num_decode_tokens
=
num_tokens
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_extends
=
0
,
num_extend_tokens
=
0
,
decode_metadata
=
decode_metadata
,
prefill_metadata
=
None
,
extend_metadata
=
None
,
use_cascade
=
False
,
common_prefix_len
=
0
,
total_tokens
=
self
.
total_tokens
,
k_scale
=
self
.
scale
,
v_scale
=
self
.
scale
,
)
def
use_cascade_attention
(
self
,
*
args
,
**
kwargs
)
->
bool
:
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment