# download_cnn_daily_mail.py
# -*- coding: utf-8 -*-
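"""Export the CNN/DailyMail test split (loaded via tensorflow_datasets) to text files.

One file receives the articles and the other the reference summaries, one
example per line in each file.
"""
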
import argparse
from pathlib import Path

import tensorflow_datasets as tfds


def main(input_path, reference_path, data_dir):
    """Dump the CNN/DailyMail test split into two line-aligned text files.

    Every example contributes exactly one line to each output file, so line
    ``i`` of the reference file is the summary for line ``i`` of the input
    file.

    Args:
        input_path: Destination file for the articles (one per line).
        reference_path: Destination file for the reference summaries
            (one per line).
        data_dir: Directory passed to ``tfds.load`` as ``data_dir``.
    """
    # shuffle_files=False keeps the example order deterministic between runs.
    cnn_ds = tfds.load("cnn_dailymail", split="test", shuffle_files=False, data_dir=data_dir)

    # Context managers guarantee both files are flushed and closed, even if
    # iteration fails part-way through; no per-record flush() is needed.
    with Path(input_path).open("w", encoding="utf-8") as articles_file:
        with Path(reference_path).open("w", encoding="utf-8") as summaries_file:
            for example in tfds.as_numpy(cnn_ds):
                # Fields are decoded from bytes before being written as text.
                article = example["article"].decode("utf-8")
                # Highlights span several lines; collapse them so each
                # summary occupies a single line of the reference file.
                summary = example["highlights"].decode("utf-8").replace("\n", " ")

                articles_file.write(article + "\n")
                summaries_file.write(summary + "\n")


if __name__ == "__main__":
    # Command-line entry point: two positional output paths plus an optional
    # directory handed to main() as data_dir.
    cli = argparse.ArgumentParser()
    cli.add_argument("input_path", type=str, help="where to save the articles input data")
    cli.add_argument("reference_path", type=str, help="where to save the reference summaries")
    cli.add_argument(
        "--data_dir",
        type=str,
        default="~/tensorflow_datasets",
        help="where to save the tensorflow datasets.",
    )
    options = cli.parse_args()
    main(options.input_path, options.reference_path, options.data_dir)