Unverified commit 2838d6b3, authored by Domen Vreš and committed by GitHub
Browse files

[Bugfix] Weight loading fix for OPT model (#9042)


Co-authored-by: dvres <dvres@fri.uni-lj.si>
parent 91add85e
...@@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module): ...@@ -353,7 +353,7 @@ class OPTForCausalLM(nn.Module):
] ]
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "lm_head.weight" in name: if "lm_head.weight" in name and self.config.tie_word_embeddings:
continue continue
if name.startswith("decoder."): if name.startswith("decoder."):
name = "model." + name name = "model." + name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment