Unverified commit b047b553, authored by fzyzcjy, committed by GitHub

[2/2] Speed up prefill mla attention concat (#10157)

parent a0f844ed
@@ -154,6 +154,7 @@ if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
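For reference, the change replaces two strided slice writes into the preallocated `k` with a single fused `sgl_kernel` op, `concat_mla_k`, when running DeepSeek V3/R1 shapes on CUDA. The sketch below is illustrative only: it reproduces the fallback path in plain PyTorch and notes how one might check the fused kernel against it. The tensor shapes follow the guard conditions in the diff (128 local heads, 128-dim nope part, 64-dim rope part); the broadcast of the rope part across heads is an assumption based on MLA's shared positional key, not something stated in this commit.

```python
import torch

def concat_k_fallback(k: torch.Tensor, k_nope: torch.Tensor, k_rope: torch.Tensor) -> None:
    """The else-branch from the diff: two slice assignments into the
    preallocated output k. The fused concat_mla_k kernel replaces exactly
    this pair of strided writes with one pass."""
    nope_dim = k_nope.shape[-1]
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_rope

# Shapes assumed from the guard in the diff (DeepSeek V3/R1).
num_tokens, num_heads, nope_dim, rope_dim = 16, 128, 128, 64

k_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=torch.bfloat16)
# In MLA the rope part of the key is shared across heads, so k_pe has a
# singleton head dim and broadcasts on write (assumption, not from the diff).
k_rope = torch.randn(num_tokens, 1, rope_dim, dtype=torch.bfloat16)

k_ref = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=torch.bfloat16)
concat_k_fallback(k_ref, k_nope, k_rope)

# On a CUDA build of sgl_kernel, the fused op should match the fallback:
#   from sgl_kernel import concat_mla_k
#   k_fused = torch.empty_like(k_ref)          # tensors moved to GPU first
#   concat_mla_k(k=k_fused, k_nope=k_nope, k_rope=k_rope)
#   assert torch.equal(k_fused.cpu(), k_ref)
```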