Unverified Commit f95c11a8 authored by hangy-amd's avatar hangy-amd Committed by GitHub
Browse files

[Feat] dflash support for ROCm (#39703)


Signed-off-by: default avatarHang Yang <hangy@amd.com>
parent 257015d5
......@@ -389,6 +389,7 @@ class AiterFlashAttentionMetadata:
seq_lens: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
causal: bool
# prefill and decode split
num_decodes: int
......@@ -676,6 +677,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor,
causal=common_attn_metadata.causal,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
......@@ -724,6 +726,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor,
causal=common_attn_metadata.causal,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_reqs,
num_decode_tokens=num_tokens,
......@@ -808,6 +811,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
# more reliable.
return on_mi3xx()
@classmethod
def supports_non_causal(cls) -> bool:
return True
class AiterFlashAttentionImpl(AttentionImpl):
def __init__(
......@@ -1122,7 +1129,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
causal=attn_metadata.causal,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
......@@ -1170,12 +1177,49 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert attn_metadata.decode_metadata is not None
decode_max_query_len = attn_metadata.decode_metadata.max_query_len
# Use unified_attention for speculative decoding (multi-token)
# Multi-token speculative decode path.
if decode_max_query_len > 1:
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
"Shuffle KV cache layout is not supported with "
"speculative decoding (multi-token decode)."
)
if not attn_metadata.causal:
from aiter.ops.triton.attention.mha_v3 import (
flash_attn_with_kvcache,
)
descale_shape = (num_decodes, key_cache.shape[2])
decode_query = query[:num_decode_tokens].reshape(
num_decodes,
decode_max_query_len,
query.shape[1],
query.shape[2],
)
decode_out = flash_attn_with_kvcache(
q=decode_query,
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=attn_metadata.seq_lens[:num_decodes],
softmax_scale=self.scale,
causal=attn_metadata.causal,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
q_descale=None,
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
page_table=attn_metadata.block_table[:num_decodes],
)
output[:num_decode_tokens].copy_(
decode_out.reshape(
num_decode_tokens,
query.shape[1],
query.shape[2],
)
)
else:
# Non-uniform query lengths can appear in real serving
# traffic (e.g. mixed datasets). Fall back to varlen
# unified_attention instead of asserting.
from aiter.ops.triton.unified_attention import (
unified_attention,
)
......@@ -1189,7 +1233,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1],
cu_seqlens_q=attn_metadata.query_start_loc[
: num_decodes + 1
],
max_seqlen_q=decode_max_query_len,
seqused_k=attn_metadata.seq_lens[:num_decodes],
max_seqlen_k=attn_metadata.max_seq_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment