Commit 277271f8 authored by zhuwenwen's avatar zhuwenwen
Browse files

update baichuan.py

parent 746d9b40
......@@ -399,6 +399,12 @@ class BaiChuanModel(nn.Module):
for layername in loaded_params:
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
# if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
......@@ -579,4 +585,4 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")
position_embedding="ROPE")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment