Unverified Commit 7cc2c9c6 authored by Beomseok Lee's avatar Beomseok Lee Committed by GitHub
Browse files

Fix bugs of s2t fairseq model converting (#15593)

* Fix bugs for argument typo and positional embedding weight loading

* Reflect code review suggestion to cover different missing keys cases
parent 7865f4d0
......@@ -94,7 +94,17 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
)
model = Speech2TextForConditionalGeneration(config)
model.model.load_state_dict(state_dict)
missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
if len(missing) > 0 and not set(missing) <= set(
[
"encoder.embed_positions.weights",
"decoder.embed_positions.weights",
]
):
raise ValueError(
f"Only `encoder.embed_positions.weights` and `decoder.embed_positions.weights` are allowed to be missing, but all the following weights are missing {missing}"
)
if tie_embeds:
model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens)
else:
......@@ -106,7 +116,7 @@ def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", type=str, help="Path to the fairseq model (.pt) file.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--fairseq_path", type=str, help="Path to the fairseq model (.pt) file.")
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, args.pytorch_dump_folder_path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment