utils.py 4.23 KB
Newer Older
1
2
import os

Zihao Ye's avatar
Zihao Ye committed
3
4
import numpy as np
import torch as th
5

Zihao Ye's avatar
Zihao Ye committed
6
7
8
from dgl.data.utils import *

# Remote archives fetched by prepare_dataset: "scripts" is downloaded to
# scripts.zip and extracted into scripts/; "wmt" is downloaded to wmt14.zip
# before running scripts/prepare-wmt14.sh.
_urls = {
    "wmt": "https://data.dgl.ai/dataset/wmt14bpe_de_en.zip",
    "scripts": "https://data.dgl.ai/dataset/transformer_scripts.zip",
}


def prepare_dataset(dataset_name):
    """Download or generate a dataset under ``data/<dataset_name>``.

    The helper-script archive is downloaded to ``scripts.zip`` and extracted
    into ``scripts/`` whenever that directory is missing. If
    ``data/<dataset_name>`` already exists, the dataset is treated as already
    prepared and the function returns without doing anything else.

    Supported names:

    * ``"multi30k"``: runs ``scripts/prepare-multi30k.sh``.
    * ``"wmt14"``: downloads ``wmt14.zip`` and runs ``scripts/prepare-wmt14.sh``.
    * ``"copy"`` / ``"tiny_copy"``: synthetic task whose target line equals
      the source line.
    * ``"sort"`` / ``"tiny_sort"``: synthetic task whose target line is the
      source characters in sorted order.

    The ``tiny_`` variants currently produce the same split sizes as their
    full counterparts. Any other name only creates the empty
    ``data/<dataset_name>`` directory.

    Args:
        dataset_name: Name of the dataset to prepare (see above).

    Raises:
        subprocess.CalledProcessError: If a preparation script exits with a
            non-zero status. On this or any other error raised while
            preparing, the freshly created ``data/<dataset_name>`` directory
            is removed so a later call retries instead of silently skipping.
    """
    import shutil

    script_dir = "scripts"
    if not os.path.exists(script_dir):
        download(_urls["scripts"], path="scripts.zip")
        extract_archive("scripts.zip", script_dir)

    directory = os.path.join("data", dataset_name)
    if os.path.exists(directory):
        return
    os.makedirs(directory)

    try:
        if dataset_name == "multi30k":
            _run_script("scripts/prepare-multi30k.sh")
        elif dataset_name == "wmt14":
            download(_urls["wmt"], path="wmt14.zip")
            _run_script("scripts/prepare-wmt14.sh")
        elif dataset_name in ("copy", "tiny_copy"):
            _generate_synthetic(directory, target_fn=None)
        elif dataset_name in ("sort", "tiny_sort"):
            _generate_synthetic(directory, target_fn=np.sort)
    except BaseException:
        # The existence of ``directory`` is the "already prepared" marker
        # checked above, so a half-finished preparation must not leave it
        # behind.
        shutil.rmtree(directory, ignore_errors=True)
        raise


# (split name, number of lines) written for the synthetic copy/sort tasks.
_SYNTHETIC_SPLITS = (("train", 9000), ("valid", 1000), ("test", 1000))


def _run_script(script_path):
    """Run ``bash <script_path>``, raising CalledProcessError on a non-zero exit."""
    import subprocess

    subprocess.run(["bash", script_path], check=True)


def _generate_synthetic(directory, target_fn=None):
    """Write ``<split>.in``/``<split>.out`` files and ``vocab.txt`` into *directory*.

    Each source line is a space-separated sequence of letters ``a``-``z``.
    Its length is drawn from ``np.random.normal(15, 3)``, truncated to int,
    and clamped to at least 1. The target line is the source line itself when
    *target_fn* is None; otherwise it is ``target_fn`` applied to the sampled
    character array. ``vocab.txt`` lists the 26 letters, one per line.
    """
    char_list = [chr(code) for code in range(ord("a"), ord("z") + 1)]
    for split, size in _SYNTHETIC_SPLITS:
        # All lengths of a split are drawn before its characters, keeping the
        # RNG consumption order (and therefore seeded output) unchanged.
        lengths = np.random.normal(15, 3, size).astype(int)
        with open(
            os.path.join(directory, split + ".in"), "w", encoding="utf-8"
        ) as f_in, open(
            os.path.join(directory, split + ".out"), "w", encoding="utf-8"
        ) as f_out:
            for length in lengths:
                seq = np.random.choice(char_list, max(length, 1))
                src_line = " ".join(seq) + "\n"
                if target_fn is None:
                    tgt_line = src_line
                else:
                    tgt_line = " ".join(target_fn(seq)) + "\n"
                f_in.write(src_line)
                f_out.write(tgt_line)

    with open(os.path.join(directory, "vocab.txt"), "w", encoding="utf-8") as f:
        f.writelines(c + "\n" for c in char_list)