Commit 277271f8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update baichuan.py

parent 746d9b40
...@@ -399,6 +399,12 @@ class BaiChuanModel(nn.Module): ...@@ -399,6 +399,12 @@ class BaiChuanModel(nn.Module):
for layername in loaded_params: for layername in loaded_params:
weight = params_dict[layername] weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
...@@ -579,4 +585,4 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -579,4 +585,4 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, super().__init__(vllm_config=vllm_config,
prefix=prefix, prefix=prefix,
position_embedding="ROPE") position_embedding="ROPE")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment