Commit 99426767 authored by zhuwenwen
Browse files

Support transposing the lm_head weight to NN layout

parent 1fabf3e1
...@@ -127,8 +127,8 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -127,8 +127,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn: if self.use_llama_nn:
if gemm_bank_conf(weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1': if gemm_bank_conf(layer.weight.shape[1] - 32) and os.environ['GEMM_PAD'] == '1':
weight = weight[:,:-32] layer.weight = layer.weight[:,:-32]
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
......
...@@ -418,7 +418,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): ...@@ -418,7 +418,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
"self_attn.W_pack.weight", "self_attn.W_pack.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
......
...@@ -416,7 +416,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA): ...@@ -416,7 +416,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
"self_attention.query_key_value.weight", "self_attention.query_key_value.weight",
"self_attention.dense.weight", "self_attention.dense.weight",
"mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight" "mlp.dense_4h_to_h.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
......
...@@ -496,7 +496,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -496,7 +496,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
......
...@@ -321,7 +321,8 @@ class QWenLMHeadModel(nn.Module): ...@@ -321,7 +321,8 @@ class QWenLMHeadModel(nn.Module):
"attn.c_attn.weight", "attn.c_attn.weight",
"attn.c_proj.weight", "attn.c_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.c_proj.weight" "mlp.c_proj.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
......
...@@ -414,7 +414,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -414,7 +414,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight",
"lm_head.weight"
] ]
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment