Unverified Commit d4d1fbfc authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[fsmt convert script] fairseq broke chkpt data - fixing that (#8377)

* fairseq broke chkpt data - fixing that

* style

* support older bpecodes filenames - specifically "code" in iwslt14
parent 5c766ecb
......@@ -113,7 +113,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
fsmt_folder_path, checkpoint_file, data_name_or_path, archive_map=models, **kwargs
)
args = dict(vars(chkpt["args"]))
args = vars(chkpt["args"]["model"])
src_lang = args["source_lang"]
tgt_lang = args["target_lang"]
......@@ -129,7 +129,7 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
src_vocab = rewrite_dict_keys(src_dict.indices)
src_vocab_size = len(src_vocab)
src_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-src.json")
print(f"Generating {src_vocab_file}")
print(f"Generating {src_vocab_file} of {src_vocab_size} of {src_lang} records")
with open(src_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(src_vocab, ensure_ascii=False, indent=json_indent))
......@@ -145,13 +145,16 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
tgt_vocab = rewrite_dict_keys(tgt_dict.indices)
tgt_vocab_size = len(tgt_vocab)
tgt_vocab_file = os.path.join(pytorch_dump_folder_path, "vocab-tgt.json")
print(f"Generating {tgt_vocab_file}")
print(f"Generating {tgt_vocab_file} of {tgt_vocab_size} of {tgt_lang} records")
with open(tgt_vocab_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tgt_vocab, ensure_ascii=False, indent=json_indent))
# merges_file (bpecodes)
merges_file = os.path.join(pytorch_dump_folder_path, VOCAB_FILES_NAMES["merges_file"])
fsmt_merges_file = os.path.join(fsmt_folder_path, "bpecodes")
for fn in ["bpecodes", "code"]: # older fairseq called the merges file "code"
fsmt_merges_file = os.path.join(fsmt_folder_path, fn)
if os.path.exists(fsmt_merges_file):
break
with open(fsmt_merges_file, encoding="utf-8") as fin:
merges = fin.read()
merges = re.sub(r" \d+$", "", merges, 0, re.M) # remove frequency number
......@@ -257,10 +260,6 @@ def convert_fsmt_checkpoint_to_pytorch(fsmt_checkpoint_path, pytorch_dump_folder
print("\nLast step is to upload the files to s3")
print(f"cd {data_root}")
print(f"transformers-cli upload {model_dir}")
print(
"Note: CDN caches files for up to 24h, so either use a local model path "
"or use `from_pretrained(mname, use_cdn=False)` to use the non-cached version."
)
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment