Unverified Commit b047b553 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

[2/2] Speed up prefill mla attention concat (#10157)

parent a0f844ed
...@@ -154,6 +154,7 @@ if _is_cuda: ...@@ -154,6 +154,7 @@ if _is_cuda:
from sgl_kernel import ( from sgl_kernel import (
awq_dequantize, awq_dequantize,
bmm_fp8, bmm_fp8,
concat_mla_k,
dsv3_fused_a_gemm, dsv3_fused_a_gemm,
dsv3_router_gemm, dsv3_router_gemm,
merge_state_v2, merge_state_v2,
...@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -1295,8 +1296,18 @@ class DeepseekV2AttentionMLA(nn.Module):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q[..., self.qk_nope_head_dim :] = q_pe q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q) k = torch.empty_like(q)
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe # Temporary for DeepSeek V3/R1 only, but can generalize if needed
if (
_is_cuda
and (self.num_local_heads == 128)
and (self.qk_nope_head_dim == 128)
and (self.qk_rope_head_dim == 64)
):
concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
else:
k[..., : self.qk_nope_head_dim] = k_nope
k[..., self.qk_nope_head_dim :] = k_pe
if not _is_npu: if not _is_npu:
latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1) latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
......
Markdown is supported
0% — or drop files here to attach.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment