Commit 112588c2 authored by zhuwenwen
Browse files

Update lm_head of llama (replace the LM_TN flag with LM_NN)

parent fa5b0b39
......@@ -22,7 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_tn = os.environ.get('LM_TN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
......@@ -42,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn and not self.use_lm_tn:
if self.use_llama_nn and self.use_lm_nn:
if bias is not None:
if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight)
......
......@@ -30,10 +30,10 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM'] or architectures == ['LlamaForCausalLM']:
os.environ['LM_TN'] = '1'
if architectures == ['BloomForCausalLM']:
os.environ['LM_NN'] = '0'
else:
os.environ['LM_TN'] = '0'
os.environ['LM_NN'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
......@@ -50,7 +50,7 @@ def get_model_architecture(
os.environ['AWQ_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
os.environ['LM_TN'] = '1'
os.environ['LM_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0'
......
......@@ -455,6 +455,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
......@@ -573,9 +574,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
# "lm_head.weight"
"mlp.down_proj.weight"
]
if self.use_lm_nn:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment