#!/usr/bin/env python

from pathlib import Path

import fire
from tqdm import tqdm


def download_wmt_dataset(src_lang="ro", tgt_lang="en", dataset="wmt16", save_dir=None) -> None:
    """Download a dataset using the datasets package and save it to the format expected by finetune.py

    Format of save_dir: train.source, train.target, val.source, val.target, test.source, test.target.
    Each split is written as two line-aligned UTF-8 text files: line i of ``{split}.source``
    is the ``src_lang`` side and line i of ``{split}.target`` is the ``tgt_lang`` side of
    the same example. The ``validation`` split is saved as ``val``.

    Args:
        src_lang: <str> source language
        tgt_lang: <str> target language
        dataset: <str> wmt16, wmt17, etc. wmt16 is a good start as it's small. To find other
            WMT datasets, search the Hugging Face Hub for "wmt".
        save_dir: <str>, where to save the datasets, defaults to f'{dataset}-{src_lang}-{tgt_lang}'.
            Missing parent directories are created.

    Raises:
        ImportError: if the ``datasets`` package is not installed.

    Usage:
        >>> download_wmt_dataset('ro', 'en', dataset='wmt16') # saves to wmt16-ro-en
    """
    try:
        import datasets
    except ImportError as err:  # also covers ModuleNotFoundError (a subclass)
        raise ImportError("run pip install datasets") from err
    pair = f"{src_lang}-{tgt_lang}"
    print(f"Converting {dataset}-{pair}")
    ds = datasets.load_dataset(dataset, pair)
    if save_dir is None:
        save_dir = f"{dataset}-{pair}"
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    for split in ds.keys():
        print(f"Splitting {split} with {ds[split].num_rows} records")

        # to save to val.source, val.target like summary datasets
        fn = "val" if split == "validation" else split
        src_path = save_dir.joinpath(f"{fn}.source")
        tgt_path = save_dir.joinpath(f"{fn}.target")

        # Explicit UTF-8 so non-ASCII text does not depend on the platform locale;
        # `with` guarantees both files are flushed and closed after each split.
        with src_path.open("w", encoding="utf-8") as src_fp:
            with tgt_path.open("w", encoding="utf-8") as tgt_fp:
                # reader is the bottleneck so writing one record at a time doesn't slow things down
                for x in tqdm(ds[split]):
                    ex = x["translation"]
                    src_fp.write(ex[src_lang] + "\n")
                    tgt_fp.write(ex[tgt_lang] + "\n")

    print(f"Saved {dataset} dataset to {save_dir}")


if __name__ == "__main__":
    # Expose download_wmt_dataset's parameters as command-line arguments via python-fire.
    fire.Fire(component=download_wmt_dataset)