Unverified commit dc937175, authored by Pleaplusone, committed by GitHub
Browse files

[ROCm][Perf] New design on ROCm AITER MHA backend Implementation (#25763)


Signed-off-by: ganyi <ygan@amd.com>
parent 2f1cc8ce
This diff is collapsed.
......@@ -728,6 +728,73 @@ def subclass_attention_backend(
)
def split_decodes_prefills_and_extends(
    common_attn_metadata: CommonAttentionMetadata,
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
    """
    Split a reordered batch into decode, extend and prefill requests.

    The batch is assumed to be ordered as: all decodes first, then all
    extends, then all prefills. Requests are classified as follows:

    * decode:  query length <= ``decode_threshold``
    * prefill: query length > ``decode_threshold`` and the sequence length
      equals the query length
    * extend:  query length > ``decode_threshold`` and the sequence length
      differs from the query length

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata. Reads ``max_query_len``, ``num_reqs``,
            ``num_actual_tokens``, ``query_start_loc_cpu`` and
            ``seq_lens_cpu``.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        A tuple of plain Python ints, in this order:

        num_decodes: The number of decode requests.
        num_extends: The number of extend requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_extend_tokens: The number of tokens in the extend requests.
        num_prefill_tokens: The number of tokens in the prefill requests.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens_cpu

    # Fast path: every request is short enough to be a decode.
    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold

    # Check for "no long request" BEFORE using argmax: argmax over an
    # all-False mask returns 0, which would wrongly report zero decodes.
    # This can happen when max_query_len overstates the real per-request
    # query lengths; the whole batch is then a decode batch.
    if not torch.any(is_prefill_or_extend):
        return num_reqs, 0, 0, num_tokens, 0, 0

    # Index of the first non-decode request == number of leading decodes.
    first_extend = int(is_prefill_or_extend.int().argmax(dim=-1).item())
    num_decodes = first_extend
    num_decode_tokens = int(query_start_loc[first_extend].item())

    num_prefills_or_extends = num_reqs - num_decodes
    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens

    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
    if not torch.any(is_prefill):
        # Every non-decode request is an extend.
        return (
            num_decodes,
            num_prefills_or_extends,
            0,
            num_decode_tokens,
            num_prefill_or_extend_tokens,
            0,
        )

    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_extends = first_prefill - num_decodes
    num_prefills = num_reqs - first_prefill

    # .item() keeps the results as Python ints instead of 0-dim tensors,
    # matching the declared return type.
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill].item())
    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens

    return (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    )
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment