Unverified Commit f95c11a8 authored by hangy-amd's avatar hangy-amd Committed by GitHub
Browse files

[Feat] dflash support for ROCm (#39703)


Signed-off-by: default avatarHang Yang <hangy@amd.com>
parent 257015d5
...@@ -389,6 +389,7 @@ class AiterFlashAttentionMetadata: ...@@ -389,6 +389,7 @@ class AiterFlashAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
block_table: torch.Tensor block_table: torch.Tensor
causal: bool
# prefill and decode split # prefill and decode split
num_decodes: int num_decodes: int
...@@ -676,6 +677,7 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -676,6 +677,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len=common_attn_metadata.max_seq_len, max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens, seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor, block_table=common_attn_metadata.block_table_tensor,
causal=common_attn_metadata.causal,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes, num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
...@@ -724,6 +726,7 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -724,6 +726,7 @@ class AiterFlashAttentionMetadataBuilder(
max_seq_len=common_attn_metadata.max_seq_len, max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens, seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor, block_table=common_attn_metadata.block_table_tensor,
causal=common_attn_metadata.causal,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_reqs, num_decodes=num_reqs,
num_decode_tokens=num_tokens, num_decode_tokens=num_tokens,
...@@ -808,6 +811,10 @@ class AiterFlashAttentionBackend(AttentionBackend): ...@@ -808,6 +811,10 @@ class AiterFlashAttentionBackend(AttentionBackend):
# more reliable. # more reliable.
return on_mi3xx() return on_mi3xx()
@classmethod
def supports_non_causal(cls) -> bool:
return True
class AiterFlashAttentionImpl(AttentionImpl): class AiterFlashAttentionImpl(AttentionImpl):
def __init__( def __init__(
...@@ -1122,7 +1129,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1122,7 +1129,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
min_seqlen_q=1, min_seqlen_q=1,
dropout_p=0.0, dropout_p=0.0,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=attn_metadata.causal,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :], out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
...@@ -1170,39 +1177,78 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1170,39 +1177,78 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata is not None
decode_max_query_len = attn_metadata.decode_metadata.max_query_len decode_max_query_len = attn_metadata.decode_metadata.max_query_len
# Use unified_attention for speculative decoding (multi-token) # Multi-token speculative decode path.
if decode_max_query_len > 1: if decode_max_query_len > 1:
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), ( assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
"Shuffle KV cache layout is not supported with " "Shuffle KV cache layout is not supported with "
"speculative decoding (multi-token decode)." "speculative decoding (multi-token decode)."
) )
from aiter.ops.triton.unified_attention import ( if not attn_metadata.causal:
unified_attention, from aiter.ops.triton.attention.mha_v3 import (
) flash_attn_with_kvcache,
)
descale_shape = (
num_decodes, descale_shape = (num_decodes, key_cache.shape[2])
key_cache.shape[2], decode_query = query[:num_decode_tokens].reshape(
) num_decodes,
unified_attention( decode_max_query_len,
q=query[:num_decode_tokens], query.shape[1],
k=key_cache, query.shape[2],
v=value_cache, )
out=output[:num_decode_tokens], decode_out = flash_attn_with_kvcache(
cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1], q=decode_query,
max_seqlen_q=decode_max_query_len, k_cache=key_cache,
seqused_k=attn_metadata.seq_lens[:num_decodes], v_cache=value_cache,
max_seqlen_k=attn_metadata.max_seq_len, cache_seqlens=attn_metadata.seq_lens[:num_decodes],
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes, window_size=self.sliding_window,
window_size=self.sliding_window, softcap=self.logits_soft_cap,
block_table=attn_metadata.block_table[:num_decodes], q_descale=None,
softcap=self.logits_soft_cap, k_descale=layer._k_scale.expand(descale_shape),
q_descale=None, v_descale=layer._v_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape), page_table=attn_metadata.block_table[:num_decodes],
v_descale=layer._v_scale.expand(descale_shape), )
) output[:num_decode_tokens].copy_(
decode_out.reshape(
num_decode_tokens,
query.shape[1],
query.shape[2],
)
)
else:
# Non-uniform query lengths can appear in real serving
# traffic (e.g. mixed datasets). Fall back to varlen
# unified_attention instead of asserting.
from aiter.ops.triton.unified_attention import (
unified_attention,
)
descale_shape = (
num_decodes,
key_cache.shape[2],
)
unified_attention(
q=query[:num_decode_tokens],
k=key_cache,
v=value_cache,
out=output[:num_decode_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[
: num_decodes + 1
],
max_seqlen_q=decode_max_query_len,
seqused_k=attn_metadata.seq_lens[:num_decodes],
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=attn_metadata.block_table[:num_decodes],
softcap=self.logits_soft_cap,
q_descale=None,
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return return
# The ll4mi kernel in paged_attention_v1 requires # The ll4mi kernel in paged_attention_v1 requires
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment