Commit 112588c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

update lm_head of llama

parent fa5b0b39
...@@ -22,7 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -22,7 +22,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def __init__(self): def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_tn = os.environ.get('LM_TN') == '1' self.use_lm_nn = os.environ.get('LM_NN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -42,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -42,7 +42,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn and not self.use_lm_tn: if self.use_llama_nn and self.use_lm_nn:
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight) return torch.addmm(bias, x, layer.weight)
......
...@@ -30,10 +30,10 @@ def get_model_architecture( ...@@ -30,10 +30,10 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if architectures == ['BloomForCausalLM'] or architectures == ['LlamaForCausalLM']: if architectures == ['BloomForCausalLM']:
os.environ['LM_TN'] = '1' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_TN'] = '0' os.environ['LM_NN'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
...@@ -50,7 +50,7 @@ def get_model_architecture( ...@@ -50,7 +50,7 @@ def get_model_architecture(
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_TN'] = '1' os.environ['LM_NN'] = '0'
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0' os.environ['AWQ_PAD'] = '0'
......
...@@ -455,6 +455,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -455,6 +455,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.quant_config=quant_config self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
...@@ -573,9 +574,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -573,9 +574,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight", "mlp.down_proj.weight"
# "lm_head.weight"
] ]
if self.use_lm_nn:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
lay_qkv_words = ["self_attn.qkv_proj.weight"] lay_qkv_words = ["self_attn.qkv_proj.weight"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment