Unverified Commit 530ae1bd authored by Lianmin Zheng, committed by GitHub

Fix weight loading for tied word embedding when TP > 1 (#2009)

parent befc6beb
@@ -380,6 +380,12 @@ class LlamaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
 
+        load_tie_word_embeddings = (
+            hasattr(self.config, "tie_word_embeddings")
+            and self.config.tie_word_embeddings
+            and "lm_head.weight" in params_dict
+        )
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -412,15 +418,14 @@ class LlamaForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        ):
+                if load_tie_word_embeddings and name == "model.embed_tokens.weight":
+                    embed_tokens_weight = loaded_weight
+
+        if load_tie_word_embeddings:
             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
             param = self.lm_head.weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)
+            weight_loader(param, embed_tokens_weight)
 
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
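
For context, the sketch below is a simplified illustration of why the change matters when tensor parallelism is enabled, not the actual sglang loader; the names load_weights_sketch and the local default_weight_loader are hypothetical stand-ins. After the loading loop, self.model.embed_tokens.weight holds only the current rank's shard, so passing it to lm_head's weight_loader would shard it a second time. Saving the full checkpoint tensor during the loop and handing that to the loader keeps the tied lm_head correct for TP > 1.

# Simplified sketch of the tied-embedding loading pattern under tensor parallelism.
# Helper names here (load_weights_sketch, default_weight_loader) are illustrative only.

def default_weight_loader(param, loaded_weight):
    # Fallback loader: plain copy for parameters without a TP-aware loader.
    param.data.copy_(loaded_weight)

def load_weights_sketch(model, weights):
    params_dict = dict(model.named_parameters())
    load_tie_word_embeddings = (
        getattr(model.config, "tie_word_embeddings", False)
        and "lm_head.weight" in params_dict
    )
    embed_tokens_weight = None

    for name, loaded_weight in weights:
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        # A TP-aware weight_loader keeps only this rank's slice of loaded_weight.
        weight_loader(param, loaded_weight)

        if load_tie_word_embeddings and name == "model.embed_tokens.weight":
            # Save the full, unsharded checkpoint tensor for the tied lm_head.
            embed_tokens_weight = loaded_weight

    if load_tie_word_embeddings and embed_tokens_weight is not None:
        param = model.lm_head.weight
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        # Passing model.embed_tokens.weight here would hand an already-sharded
        # tensor to a loader that expects the full weight; the saved checkpoint
        # tensor avoids double sharding when TP > 1.
        weight_loader(param, embed_tokens_weight)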