Unverified commit 2b981012, authored by firebook, committed by GitHub
Browse files

Fix Baichuan2-7B-Chat (#1987)

parent 6ccc0bff
@@ -366,11 +366,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM
+                          ):  # baichuan 13b, baichuan2 13b, baichuan2 7b
 
     def __init__(self,
                  config,
                  linear_method: Optional[LinearMethodBase] = None):
-        super().__init__(config, "ALIBI", linear_method)
+        if config.hidden_size == 4096:  # baichuan2 7b
+            super().__init__(config, "ROPE", linear_method)
+        else:  # baichuan 13b, baichuan2 13b
+            super().__init__(config, "ALIBI", linear_method)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment