"""Download a dataset using the nlp package and save it to the format expected by finetune.py
Format of save_dir: train.source, train.target, val.source, val.target, test.source, test.target.
Args:
    src_lang: <str> source language
    tgt_lang: <str> target language
    dataset: <str> wmt16, wmt17, etc. wmt16 is a good start as it's small. To get the full list run `import nlp; print([d.id for d in nlp.list_datasets() if "wmt" in d.id])`
    save_dir: <str> where to save the datasets, defaults to f'{dataset}-{src_lang}-{tgt_lang}'

Usage:
    >>> download_wmt_dataset('en', 'ru', dataset='wmt19')  # saves to wmt19-en-ru
    >>> download_wmt_dataset('ro', 'en', dataset='wmt16')  # saves to wmt16-ro-en
"""
try:
    import nlp
except (ModuleNotFoundError, ImportError):
    raise ImportError("run pip install nlp")
pair = f"{src_lang}-{tgt_lang}"
print(f"Converting {dataset}-{pair}")
ds = nlp.load_dataset(dataset, pair)
if save_dir is None:
    save_dir = f"{dataset}-{pair}"
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True)
for split in ds.keys():
    print(f"Splitting {split} with {ds[split].num_rows} records")

    # to save to val.source, val.target like summary datasets
    fn = "val" if split == "validation" else split

    src_path = save_dir.joinpath(f"{fn}.source")
    tgt_path = save_dir.joinpath(f"{fn}.target")
    src_fp = src_path.open("w+")
    tgt_fp = tgt_path.open("w+")

    # reader is the bottleneck so writing one record at a time doesn't slow things down
    for x in tqdm(ds[split]):
        ex = x["translation"]
        src_fp.write(ex[src_lang] + "\n")
        tgt_fp.write(ex[tgt_lang] + "\n")

print(f"saved dataset to {save_dir}")
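# A minimal sketch of how this function could be invoked as a script, using the
# call shown in the docstring's Usage examples. The __main__ guard is an
# assumption for illustration; it is not part of the function above, and the
# upstream script may expose its own CLI wrapper instead.
if __name__ == "__main__":
    # writes wmt16-ro-en/{train,val,test}.{source,target}
    download_wmt_dataset("ro", "en", dataset="wmt16")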