Commit 6b6e7487, authored by Ke Bao, committed by GitHub

Remove q concat in FA3 backend for DeepSeek decode (#5638)

parent 91732486
@@ -62,6 +62,7 @@ class AttentionBackend(ABC):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
@@ -72,6 +73,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )
         else:
             return self.forward_extend(
@@ -81,6 +83,7 @@ class AttentionBackend(ABC):
                 layer,
                 forward_batch,
                 save_kv_cache=save_kv_cache,
+                **kwargs,
             )

     def forward_decode(
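The dispatcher change above only threads extra keyword arguments through; backends that ignore them are unaffected. A minimal sketch of the pattern, using toy names (ToyAttentionBackend, a plain dict standing in for forward_batch) that are not part of sglang:

# Sketch only: toy stand-ins, not sglang's real classes or signatures.
class ToyAttentionBackend:
    def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True, **kwargs):
        # Dispatch on decode vs. extend and pass any extra keyword
        # arguments (e.g. q_rope) straight through, as in the hunk above.
        if forward_batch.get("is_decode"):
            return self.forward_decode(
                q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
            )
        return self.forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
        )

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True, q_rope=None):
        # Only a backend that understands q_rope has to name it explicitly.
        return ("decode", q_rope is not None)

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True, q_rope=None):
        return ("extend", q_rope is not None)

backend = ToyAttentionBackend()
print(backend.forward(None, None, None, None, {"is_decode": True}, q_rope="pe"))
# -> ('decode', True)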
@@ -623,6 +623,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -815,9 +817,15 @@ class FlashAttentionBackend(AttentionBackend):
            c_kv_cache = c_kv.view(
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )
-           q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-           q_nope = q_all[:, :, : layer.v_head_dim]
-           q_rope = q_all[:, :, layer.v_head_dim :]
+           if q_rope is not None:
+               q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+               q_rope = q_rope.view(
+                   -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+               )
+           else:
+               q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+               q_nope = q_all[:, :, : layer.v_head_dim]
+               q_rope = q_all[:, :, layer.v_head_dim :]

            result = flash_attn_with_kvcache(
                q=q_rope,
@@ -877,6 +885,8 @@ class FlashAttentionBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -1047,9 +1057,15 @@ class FlashAttentionBackend(AttentionBackend):
                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
            )
-           q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-           q_nope = q_all[:, :, : layer.v_head_dim]
-           q_rope = q_all[:, :, layer.v_head_dim :]
+           if q_rope is not None:
+               q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+               q_rope = q_rope.view(
+                   -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+               )
+           else:
+               q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+               q_nope = q_all[:, :, : layer.v_head_dim]
+               q_rope = q_all[:, :, layer.v_head_dim :]

            max_seqlen_q = metadata.max_seq_len_q
            result = flash_attn_with_kvcache(
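Both the extend and decode paths gain the same branch: when q_rope arrives as its own tensor, the backend no longer has to build a contiguous fused q and slice it apart. A rough, standalone illustration of the two layouts (the shapes below are illustrative DeepSeek-MLA-like values, not read from the code):

import torch

# Illustrative shapes only: head_dim = v_head_dim + rope_dim.
num_tokens, num_heads = 4, 16
v_head_dim, rope_dim = 512, 64
head_dim = v_head_dim + rope_dim

# Old path: the caller concatenated the nope/rope parts, so the backend had to
# make a contiguous copy before slicing them apart again.
q_fused = torch.randn(num_tokens, num_heads, head_dim)
q_all = q_fused.contiguous().view(-1, num_heads, head_dim)
q_nope_old = q_all[:, :, :v_head_dim]
q_rope_old = q_all[:, :, v_head_dim:]

# New path: q (the nope part) and q_rope arrive separately, so only cheap
# reshapes are needed and the concat/contiguous round trip disappears.
q_nope_in = torch.randn(num_tokens, num_heads * v_head_dim)
q_rope_in = torch.randn(num_tokens, num_heads * rope_dim)
q_nope_new = q_nope_in.view(-1, num_heads, v_head_dim)
q_rope_new = q_rope_in.view(-1, num_heads, rope_dim)

assert q_nope_old.shape == q_nope_new.shape
assert q_rope_old.shape == q_rope_new.shape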
@@ -87,6 +87,7 @@ class RadixAttention(nn.Module):
         v,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        **kwargs,
     ):
         if k is not None:
             # For cross-layer sharing, kv can be None
@@ -95,5 +96,11 @@ class RadixAttention(nn.Module):
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         return forward_batch.attn_backend.forward(
-            q, k, v, self, forward_batch, save_kv_cache
+            q,
+            k,
+            v,
+            self,
+            forward_batch,
+            save_kv_cache,
+            **kwargs,
         )
@@ -751,10 +751,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        q = torch.cat([q_nope_out, q_pe], dim=-1)
         k = torch.cat([k_nope, k_pe], dim=-1)

-        attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+        if self.attention_backend == "fa3":
+            attn_output = self.attn_mqa(
+                q_nope_out, k, k_nope, forward_batch, q_rope=q_pe
+            )
+        else:
+            q = torch.cat([q_nope_out, q_pe], dim=-1)
+            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
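On the model side, only the FA3 backend takes the new separated layout; other backends keep the original concatenated q. A quick standalone check (illustrative shapes; variable names mirror the hunk above) that the two branches hand the backend the same data, just laid out differently:

import torch

# Illustrative shapes; kv_lora_rank and the rope width are stand-in values.
num_tokens, num_heads, kv_lora_rank, rope_dim = 4, 16, 512, 64
q_nope_out = torch.randn(num_tokens, num_heads, kv_lora_rank)
q_pe = torch.randn(num_tokens, num_heads, rope_dim)

# Non-FA3 path: fuse q, then let the backend split it again.
q = torch.cat([q_nope_out, q_pe], dim=-1)

# The slices the backend recovers equal what the FA3 path passes directly,
# so skipping the concat changes copies and layout, not results.
assert torch.equal(q[..., :kv_lora_rank], q_nope_out)
assert torch.equal(q[..., kv_lora_rank:], q_pe)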