Unverified Commit 3f87f831 authored by Baizhou Zhang, committed by GitHub

Fuse q_a_proj and kv_a_proj (#5619)

parent ce5412b6
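The idea: q_a_proj and kv_a_proj_with_mqa both project the same hidden states, so their weights can be stacked along the output dimension and run as a single GEMM, with the result split back into the query latent and the joint KV/rope latent. A minimal sketch of the equivalence, using plain tensors and made-up sizes rather than the actual ReplicatedLinear modules in the diff:

import torch

# Illustrative sizes only, not the real DeepSeek-V2 dimensions.
hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 32, 8, 4, 2

w_q = torch.randn(q_lora_rank, hidden_size)                       # q_a_proj weight
w_kv = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size)  # kv_a_proj_with_mqa weight
x = torch.randn(3, hidden_size)                                   # hidden_states

# Before: two separate GEMMs over the same input.
q_ref, latent_ref = x @ w_q.T, x @ w_kv.T

# After: one weight concatenated along the output dimension, one GEMM,
# then split the output the same way the new forward code does.
w_fused = torch.cat([w_q, w_kv], dim=0)
q, latent_cache = (x @ w_fused.T).split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1
)
assert torch.allclose(q, q_ref) and torch.allclose(latent_cache, latent_ref)

The point is to replace two small GEMMs over the same activations with one larger one; the fusion only applies when q_lora_rank is not None, since otherwise there is no q_a_proj to fuse with.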
@@ -443,12 +443,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         # For tensor parallel attention
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
-                self.q_lora_rank,
+                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
+                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -470,6 +470,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
             )
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
+            )
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -490,14 +498,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
         )
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
         if rope_scaling:
@@ -656,15 +656,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -699,13 +702,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         if self.use_deep_gemm_bmm:
@@ -744,7 +750,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         k_nope = latent_cache[..., : self.kv_lora_rank]
         k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
@@ -819,13 +824,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
         )
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         if self.w_kc.dtype == torch.float8_e4m3fnuz:
@@ -846,8 +854,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
         v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
         k_input = latent_cache.unsqueeze(1)
@@ -1018,15 +1024,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         # First do normal mha forward to get output for extended part
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -1668,6 +1676,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             # TODO(HandH1998): Modify it when nextn is supported.
@@ -1723,6 +1737,45 @@ class DeepseekV2ForCausalLM(nn.Module):
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = name.replace(
+                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                            )
+                            param = params_dict[param_name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                    else:
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
......
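On the loading side, load_weights now buffers each layer's q_a_proj and kv_a_proj_with_mqa checkpoint tensors in cached_a_proj and, once both have arrived, concatenates them along dim 0 and hands the result to the fused parameter's weight_loader. A rough standalone sketch of that caching pattern, with hypothetical weight names and without the weight_loader/quantization plumbing of the real loader:

import torch

# Hypothetical (name, tensor) pairs as they might stream out of a checkpoint.
weights = [
    ("layers.0.self_attn.kv_a_proj_with_mqa.weight", torch.randn(6, 32)),
    ("layers.0.self_attn.q_a_proj.weight", torch.randn(8, 32)),
]

fused_params = {}    # stands in for the params_dict entries of fused_qkv_a_proj_with_mqa
cached_a_proj = {}

for name, tensor in weights:
    cached_a_proj[name] = tensor
    # Derive both names of the pair from whichever tensor just arrived.
    q_name = name if "q_a_proj" in name else name.replace("kv_a_proj_with_mqa", "q_a_proj")
    kv_name = name if "kv_a_proj_with_mqa" in name else name.replace("q_a_proj", "kv_a_proj_with_mqa")
    if q_name in cached_a_proj and kv_name in cached_a_proj:
        # q part first, then kv part, matching the split order in the forward pass.
        fused = torch.cat([cached_a_proj.pop(q_name), cached_a_proj.pop(kv_name)], dim=0)
        fused_params[q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")] = fused

assert fused_params["layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight"].shape == (14, 32)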