Unverified commit 9eec4e93, authored by Suraj Patil, committed by GitHub
Browse files

[M2M100] update conversion script (#17916)

parent db2644b9
......@@ -44,7 +44,7 @@ def make_linear_from_emb(emb):
def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
m2m_100 = torch.load(checkpoint_path, map_location="cpu")
args = m2m_100["args"]
args = m2m_100["args"] or m2m_100["cfg"]["model"]
state_dict = m2m_100["model"]
remove_ignore_keys_(state_dict)
vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
......@@ -69,7 +69,7 @@ def convert_fairseq_m2m100_checkpoint_from_disk(checkpoint_path):
state_dict["shared.weight"] = state_dict["decoder.embed_tokens.weight"]
model = M2M100ForConditionalGeneration(config)
model.model.load_state_dict(state_dict)
model.model.load_state_dict(state_dict, strict=False)
model.lm_head = make_linear_from_emb(model.model.shared)
return model
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment