Unverified Commit d4d1fbfc authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[fsmt convert script] fairseq broke chkpt data - fixing that (#8377)

* fairseq broke chkpt data - fixing that

* style

* support older bpecodes filenames - specifically "code" in iwslt14
parent 5c766ecb
...@@ -113,7 +113,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder ...@@ -113,7 +113,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
fsmt_folder_path, checkpoint_file, data_name_or_path, archive_map=models, **kwargs fsmt_folder_path, checkpoint_file, data_name_or_path, archive_map=models, **kwargs
) )
args = dict(vars(chkpt["args"])) args = vars(chkpt["args"]["model"])
src_lang = args["source_lang"] src_lang = args["source_lang"]
tgt_lang = args["target_lang"] tgt_lang = args["target_lang"]
...@@ -129,7 +129,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder ...@@ -129,7 +129,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
src_vocab = rewrite_dict_keys(src_dict.indices) src_vocab = rewrite_dict_keys(src_dict.indices)
src_vocab_size = len(src_vocab) src_vocab_size = len(src_vocab)
src_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-src.json") src_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-src.json")
print(f"Generating {src_vocab_file}") print(f"Generating {src_vocab_file} of {src_vocab_size} of {src_lang} records")
with open(src_vocab_file, "w", encoding="utf-8") as f: with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent)) f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
...@@ -145,13 +145,16 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder ...@@ -145,13 +145,16 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
tgt_vocab = rewrite_dict_keys(tgt_dict.indices) tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
tgt_vocab_size = len(tgt_vocab) tgt_vocab_size = len(tgt_vocab)
tgt_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-tgt.json") tgt_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-tgt.json")
print(f"Generating {tgt_vocab_file}") print(f"Generating {tgt_vocab_file} of {tgt_vocab_size} of {tgt_lang} records")
with open(tgt_vocab_file, "w", encoding="utf-8") as f: with open(tgt_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tgt_vocab, ensure_ascii=False, indent=json_indent)) f.write(json.dumps(tgt_vocab, ensure_ascii=False, indent=json_indent))
# merges_file (bpecodes) # merges_file (bpecodes)
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"]) merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
fsmt_merges_file = os.path.join(fsmt_folder_path, "bpecodes") for fn in ["bpecodes", "code"]: # older fairseq called the merges file "code"
fsmt_merges_file = os.path.join(fsmt_folder_path, fn)
if os.path.exists(fsmt_merges_file):
break
with open(fsmt_merges_file, encoding="utf-8") as fin: with open(fsmt_merges_file, encoding="utf-8") as fin:
merges = fin.read() merges = fin.read()
merges = re.sub(r" \d+$", "", merges, 0, re.M) # remove frequency number merges = re.sub(r" \d+$", "", merges, 0, re.M) # remove frequency number
...@@ -257,10 +260,6 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder ...@@ -257,10 +260,6 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
print("\nLast step is to upload the files to s3") print("\nLast step is to upload the files to s3")
print(f"cd {data_root}") print(f"cd {data_root}")
print(f"transformers-cli upload {model_dir}") print(f"transformers-cli upload {model_dir}")
print(
"Note: CDN caches files for up to 24h, so either use a local model path "
"or use `from_pretrained(mname, use_cdn=False)` to use the non-cached version."
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment