Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc937175
Unverified
Commit
dc937175
authored
Nov 05, 2025
by
Pleaplusone
Committed by
GitHub
Nov 04, 2025
Browse files
[ROCm][Perf] New design on ROCm AITER MHA backend Implementation (#25763)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
2f1cc8ce
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
593 additions
and
275 deletions
+593
-275
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+526
-275
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+67
-0
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
dc937175
This diff is collapsed.
Click to expand it.
vllm/v1/attention/backends/utils.py
View file @
dc937175
...
...
@@ -728,6 +728,73 @@ def subclass_attention_backend(
)
def
split_decodes_prefills_and_extends
(
common_attn_metadata
:
CommonAttentionMetadata
,
decode_threshold
:
int
=
1
,
)
->
tuple
[
int
,
int
,
int
,
int
,
int
,
int
]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
Returns:
num_decodes: The number of decode requests.
num_extends: The number of extend requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_extend_tokens: The number of tokens in the extend requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len
=
common_attn_metadata
.
max_query_len
num_reqs
=
common_attn_metadata
.
num_reqs
num_tokens
=
common_attn_metadata
.
num_actual_tokens
query_start_loc
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens_cpu
if
max_query_len
<=
decode_threshold
:
return
num_reqs
,
0
,
0
,
num_tokens
,
0
,
0
query_lens
=
query_start_loc
[
1
:]
-
query_start_loc
[:
-
1
]
is_prefill_or_extend
=
query_lens
>
decode_threshold
is_prefill
=
(
seq_lens
==
query_lens
)
&
is_prefill_or_extend
first_extend
=
is_prefill_or_extend
.
int
().
argmax
(
dim
=-
1
).
item
()
first_prefill
=
is_prefill
.
int
().
argmax
(
dim
=-
1
).
item
()
num_decodes
=
first_extend
num_decode_tokens
=
query_start_loc
[
first_extend
].
item
()
if
not
torch
.
any
(
is_prefill_or_extend
):
return
(
num_decodes
,
0
,
0
,
num_decode_tokens
,
0
,
0
)
num_prefills_or_extends
=
num_reqs
-
num_decodes
num_prefill_or_extend_tokens
=
num_tokens
-
num_decode_tokens
if
not
torch
.
any
(
is_prefill
):
return
(
num_decodes
,
num_prefills_or_extends
,
0
,
num_decode_tokens
,
num_prefill_or_extend_tokens
,
0
,
)
num_extends
=
first_prefill
-
num_decodes
num_prefills
=
num_reqs
-
first_prefill
num_prefill_tokens
=
num_tokens
-
query_start_loc
[
first_prefill
]
num_extend_tokens
=
num_prefill_or_extend_tokens
-
num_prefill_tokens
return
(
num_decodes
,
num_extends
,
num_prefills
,
num_decode_tokens
,
num_extend_tokens
,
num_prefill_tokens
,
)
def
split_decodes_and_prefills
(
common_attn_metadata
:
CommonAttentionMetadata
,
decode_threshold
:
int
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment