Commit 1428c17d authored by zhuwenwen's avatar zhuwenwen
Browse files

auto convert lm_head layout of llama

parent 85def94c
...@@ -22,7 +22,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -22,7 +22,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def __init__(self): def __init__(self):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -42,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -42,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.use_llama_nn and self.use_lm_nn: if self.use_llama_nn and os.environ['LM_NN'] == '1':
if bias is not None: if bias is not None:
if len(x.shape) == 2: if len(x.shape) == 2:
return torch.addmm(bias, x, layer.weight) return torch.addmm(bias, x, layer.weight)
......
...@@ -455,7 +455,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -455,7 +455,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.quant_config=quant_config self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_lm_nn = os.environ.get('LM_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
...@@ -574,8 +573,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -574,8 +573,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"self_attn.qkv_proj.weight", "self_attn.qkv_proj.weight",
"self_attn.o_proj.weight", "self_attn.o_proj.weight",
"mlp.gate_up_proj.weight", "mlp.gate_up_proj.weight",
"mlp.down_proj.weight" "mlp.down_proj.weight",
# "lm_head.weight" "lm_head.weight"
] ]
if self.use_lm_nn: if self.use_lm_nn:
...@@ -587,6 +586,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -587,6 +586,10 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qkv_words = "|".join(lay_qkv_words) qkv_words = "|".join(lay_qkv_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment