"vscode:/vscode.git/clone" did not exist on "c4a63d77a12c37aa1a2c6766666f653b4d89e0c7"
Unverified Commit f414352a authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Transpose mla weight offline (#1261)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent a362340b
@@ -417,12 +417,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )
 
-        kv_b_proj = self.kv_b_proj
-        w_kc, w_vc = kv_b_proj.weight.unflatten(
-            0, (-1, qk_nope_head_dim + v_head_dim)
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
 
     def forward(
         self,
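The removed block above reshaped the fused kv_b_proj weight into per-head w_kc / w_vc pieces at construction time; that work now runs once after weight loading instead (see the last hunk). A minimal sketch of the shape arithmetic, using toy dimensions that are assumptions for illustration rather than DeepSeek-V2's actual config:

import torch

# Toy dimensions, assumed for illustration only.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# kv_b_proj.weight stacks every head's rows:
# [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Unflatten dim 0 into (num_heads, qk_nope_head_dim + v_head_dim), then split
# each head's rows into the key (nope) part and the value part.
w_kc, w_vc = weight.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
assert w_kc.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
assert w_vc.shape == (num_heads, v_head_dim, kv_lora_rank)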
@@ -464,7 +460,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         torch.bmm(
             attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
+            self.w_vc,
             out=attn_bmm_output.transpose(0, 1),
         )
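This bmm is the per-token hot path the commit targets: previously every forward pass re-ran transpose(1, 2).contiguous(), which materializes a fresh copy of w_vc, whereas now forward consumes the tensor that was transposed once at load time. A minimal equivalence sketch with toy shapes (all dimensions are assumptions):

import torch

attn_output = torch.randn(16, 4, 512)  # [tokens, num_heads, kv_lora_rank], toy shapes
w_vc = torch.randn(4, 128, 512)        # [num_heads, v_head_dim, kv_lora_rank]

# Before this commit: transpose + copy on every forward pass.
per_call = torch.bmm(attn_output.transpose(0, 1), w_vc.transpose(1, 2).contiguous())

# After: transpose once offline, reuse the stored result on every pass.
w_vc_offline = w_vc.transpose(1, 2).contiguous()
offline = torch.bmm(attn_output.transpose(0, 1), w_vc_offline)

assert torch.equal(per_call, offline)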
@@ -715,5 +711,15 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             weight_loader(param, loaded_weight)
 
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.contiguous()
+                self_attn.w_vc = w_vc.transpose(1, 2).contiguous()
+                del self_attn.kv_b_proj
+
 
 EntryClass = DeepseekV2ForCausalLM
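One detail worth noting in the new loading code: unflatten and split return views backed by kv_b_proj.weight, so the .contiguous() calls are what give w_kc and w_vc storage of their own, which in turn makes del self_attn.kv_b_proj safe and lets the fused weight's memory be reclaimed. A minimal sketch of that view-versus-copy distinction (toy tensor, not the real projection):

import torch

weight = torch.randn(8, 4)  # stand-in for a fused projection weight

view = weight.unflatten(0, (-1, 4))[:, :2]   # unflatten + slice: still a view
assert view.data_ptr() == weight.data_ptr()  # shares storage with `weight`

copy = view.contiguous()                     # non-contiguous view -> real copy
assert copy.data_ptr() != weight.data_ptr()  # own storage; `weight` can be freed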