Unverified Commit 6b6e7487 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Remove q concat in FA3 backend for DeepSeek decode (#5638)

parent 91732486
......@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
"""Run forward on an attention layer."""
if forward_batch.forward_mode.is_decode():
......@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)
else:
return self.forward_extend(
......@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
layer,
forward_batch,
save_kv_cache=save_kv_cache,
**kwargs,
)
def forward_decode(
......
......@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
):
if k is not None:
assert v is not None
......@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
c_kv_cache = c_kv.view(
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
result = flash_attn_with_kvcache(
q=q_rope,
......@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if k is not None:
assert v is not None
......@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
-1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
)
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
if q_rope is not None:
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
q_rope = q_rope.view(
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
)
else:
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
max_seqlen_q = metadata.max_seq_len_q
result = flash_attn_with_kvcache(
......
......@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
v,
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
**kwargs,
):
if k is not None:
# For cross-layer sharing, kv can be None
......@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
return forward_batch.attn_backend.forward(
q, k, v, self, forward_batch, save_kv_cache
q,
k,
v,
self,
forward_batch,
save_kv_cache,
**kwargs,
)
......@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
if self.attention_backend == "fa3":
attn_output = self.attn_mqa(
q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.use_deep_gemm_bmm:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment