Commit a528f350 authored by zhuwenwen
Browse files

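Update the FP16 model layout conversion condition: the layout-conversion branch now runs only when `use_llama_nn` is set and no quantization method is configured (`self.quant_method is None`).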

parent f82f451f
...@@ -404,7 +404,7 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -404,7 +404,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.W_pack.weight", "self_attn.W_pack.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
......
...@@ -404,7 +404,7 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -404,7 +404,7 @@ class ChatGLMForCausalLM(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attention.query_key_value.weight", "self_attention.query_key_value.weight",
"self_attention.dense.weight", "self_attention.dense.weight",
......
...@@ -453,7 +453,7 @@ class LlamaForCausalLM(nn.Module): ...@@ -453,7 +453,7 @@ class LlamaForCausalLM(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
......
...@@ -309,7 +309,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -309,7 +309,7 @@ class QWenLMHeadModel(nn.Module):
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"attn.c_attn.weight", "attn.c_attn.weight",
"attn.c_proj.weight", "attn.c_proj.weight",
......
...@@ -396,7 +396,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -396,7 +396,7 @@ class Qwen2ForCausalLM(nn.Module):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if self.use_llama_nn: if self.use_llama_nn and self.quant_method is None:
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment