Commit f0e7d72d authored by zhuwenwen
Browse files

Update lm_head TN layout for AWQ

parent fce0353c
...@@ -519,6 +519,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -519,6 +519,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq": if self.quant_method == "awq":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"self_attn.W_pack.qweight", "self_attn.W_pack.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
......
...@@ -875,4 +875,4 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -875,4 +875,4 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass pass
\ No newline at end of file
...@@ -507,6 +507,7 @@ class LlamaModel(nn.Module): ...@@ -507,6 +507,7 @@ class LlamaModel(nn.Module):
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq": if self.quant_method == "awq":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
......
...@@ -1132,6 +1132,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA): ...@@ -1132,6 +1132,7 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq": if self.quant_method == "awq":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"attn.c_attn.qweight", "attn.c_attn.qweight",
"attn.c_proj.qweight", "attn.c_proj.qweight",
......
...@@ -485,6 +485,7 @@ class Qwen2Model(nn.Module): ...@@ -485,6 +485,7 @@ class Qwen2Model(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq": if self.quant_method == "awq":
os.environ['LM_NN'] = '0'
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment