"vscode:/vscode.git/clone" did not exist on "bfa0e33294f2b1dc25e65a33be2397f989824298"
Unverified commit e30c273b authored by xu-yfei, committed by GitHub

opt flashinfer mla cat (#5822)


Co-authored-by: xuyongfei.xyf <xuyongfei.xyf@antgroup.com>
parent 0ab3f437
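
The optimization: instead of concatenating the rotary ("rope") halves of q and k onto their "nope" halves and then slicing them apart again inside the attention backend, the backend now accepts optional q_rope/k_rope tensors and keeps the two halves separate on the paged prefill and decode paths. A minimal sketch of the redundant round trip being removed (shapes are illustrative, not the model's real dimensions):

```python
# Minimal sketch (not SGLang code): each MLA query head has a "nope" part
# (v_head_dim) and a "rope" part; the dimensions below are made up.
import torch

num_tokens, num_heads = 4, 16
v_head_dim, rope_dim = 512, 64

q_nope = torch.randn(num_tokens, num_heads, v_head_dim)
q_rope = torch.randn(num_tokens, num_heads, rope_dim)

# Old path: concatenate, then immediately slice the two halves back out for the kernel.
qall = torch.cat([q_nope, q_rope], dim=-1)           # extra allocation + copy
q_a, q_b = qall[..., :v_head_dim], qall[..., v_head_dim:]

# New path: skip the cat and hand the kernel the two tensors it wanted all along.
assert torch.equal(q_a, q_nope) and torch.equal(q_b, q_rope)
```
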
@@ -339,22 +339,38 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
         if self.forward_metadata.use_ragged:
             # ragged prefill
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
             o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
@@ -365,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
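
For reference, the paged-prefill branch above stays backward compatible: callers that still pass a fused q (with q_rope=None) get it split into nope/rope views before the kernel call. A standalone sketch of that split, with hypothetical head counts and dimensions:

```python
# Standalone sketch of the fallback split in the paged-prefill branch above.
# tp_q_head_num / head_dim / v_head_dim are hypothetical values, not read from a config.
from typing import Optional

import torch

tp_q_head_num, v_head_dim, head_dim = 16, 512, 576

def split_q(q: torch.Tensor, q_rope: Optional[torch.Tensor]):
    if q_rope is None:
        # Old-style caller: q carries both halves fused; slice them apart.
        qall = q.view(-1, tp_q_head_num, head_dim)
        q, q_rope = qall[:, :, :v_head_dim], qall[:, :, v_head_dim:]
    else:
        # New-style caller: the halves arrive separately and only need reshaping.
        q = q.view(-1, tp_q_head_num, v_head_dim)
        q_rope = q_rope.view(-1, tp_q_head_num, head_dim - v_head_dim)
    return q, q_rope

fused = torch.randn(4, tp_q_head_num * head_dim)
q_nope, q_pe = split_q(fused, None)
print(q_nope.shape, q_pe.shape)  # torch.Size([4, 16, 512]) torch.Size([4, 16, 64])
```
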
@@ -382,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -389,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Reshape inputs
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
......
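
The decode path mirrors the prefill change: when k_rope is provided, the cache write goes through set_mla_kv_buffer, which (per the commit title) is presumably intended to place the nope and rope parts into the MLA KV pool without first materializing torch.cat([k, k_rope], dim=-1). A self-contained sketch of that idea; the real helper and pool layout live in SGLang's memory pool, so the names and shapes here are assumptions:

```python
# Hypothetical sketch of writing nope/rope halves into a combined KV cache
# without concatenating them first. pool_size, kv_lora_rank, rope_dim are made up.
import torch

pool_size, kv_lora_rank, rope_dim = 1024, 512, 64
kv_cache = torch.zeros(pool_size, 1, kv_lora_rank + rope_dim)

def set_mla_kv_buffer(cache_loc: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor):
    # Write each half directly into its slice of the cache rows at cache_loc.
    kv_cache[cache_loc, :, :kv_lora_rank] = k_nope
    kv_cache[cache_loc, :, kv_lora_rank:] = k_rope

cache_loc = torch.tensor([3, 7, 42])
k_nope = torch.randn(3, 1, kv_lora_rank)
k_rope = torch.randn(3, 1, rope_dim)
set_mla_kv_buffer(cache_loc, k_nope, k_rope)
assert torch.equal(kv_cache[7, :, kv_lora_rank:], k_rope[1])
```
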
@@ -777,7 +777,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)

         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
-        if self.attention_backend == "fa3":
+        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
......
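
The model-side hunk widens the dispatch so the flashinfer backend takes the same MQA path as fa3, passing the rotary parts q_pe/k_pe through the new q_rope/k_rope arguments. As a readability aside (not the project's code), the chained comparison could also be written as a set-membership test:

```python
# Sketch only: equivalent dispatch condition expressed as set membership.
MQA_ROPE_BACKENDS = {"fa3", "flashinfer"}

def use_mqa_with_rope(attention_backend: str) -> bool:
    return attention_backend in MQA_ROPE_BACKENDS

assert use_mqa_with_rope("flashinfer") and not use_mqa_with_rope("triton")
```
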