utils.py
import os

import numpy as np
import torch as th
from dgl.data.utils import download, extract_archive

# Remote archives: the preprocessed WMT en-de corpus and the shell scripts used
# to prepare the text datasets.
_urls = {
    'wmt': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/wmt16_en_de.tar.gz',
    'scripts': 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/transformer_scripts.zip',
}

def prepare_dataset(dataset_name):
    """Download and generate the requested dataset."""
    script_dir = os.path.join('scripts')
    if not os.path.exists(script_dir):
        # Fetch the data-preparation shell scripts on first use.
        download(_urls['scripts'], path='scripts.zip')
        extract_archive('scripts.zip', 'scripts')

    directory = os.path.join('data', dataset_name)
    if os.path.exists(directory):
        # The dataset has already been prepared; nothing to do.
        return
    os.makedirs(directory)
    if dataset_name == 'multi30k':
        # Multi30k: run the bundled preparation script.
        os.system('bash scripts/prepare-multi30k.sh')
    elif dataset_name == 'wmt14':
        # Download the preprocessed en-de archive, then run the preparation script.
        download(_urls['wmt'], path='wmt16_en_de.tar.gz')
        os.system('bash scripts/prepare-wmt14.sh')
    elif dataset_name in ('copy', 'tiny_copy'):
        # Synthetic copy task: the target sequence is identical to the source.
        split_sizes = {'train': 9000, 'valid': 1000, 'test': 1000}
        char_list = [chr(i) for i in range(ord('a'), ord('z') + 1)]
        for split, size in split_sizes.items():
            with open(os.path.join(directory, split + '.in'), 'w') as f_in, \
                 open(os.path.join(directory, split + '.out'), 'w') as f_out:
                # Sequence lengths are drawn from N(15, 3) and clipped to at least 1.
                for l in np.random.normal(15, 3, size).astype(int):
                    l = max(l, 1)
                    line = ' '.join(np.random.choice(char_list, l)) + '\n'
                    f_in.write(line)
                    f_out.write(line)

        with open(os.path.join(directory, 'vocab.txt'), 'w') as f:
            for c in char_list:
                f.write(c + '\n')

    elif dataset_name in ('sort', 'tiny_sort'):
        # Synthetic sort task: the target is the source characters in sorted order.
        split_sizes = {'train': 9000, 'valid': 1000, 'test': 1000}
        char_list = [chr(i) for i in range(ord('a'), ord('z') + 1)]
        for split, size in split_sizes.items():
            with open(os.path.join(directory, split + '.in'), 'w') as f_in, \
                 open(os.path.join(directory, split + '.out'), 'w') as f_out:
                # Sequence lengths are drawn from N(15, 3) and clipped to at least 1.
                for l in np.random.normal(15, 3, size).astype(int):
                    l = max(l, 1)
                    seq = np.random.choice(char_list, l)
                    f_in.write(' '.join(seq) + '\n')
                    f_out.write(' '.join(np.sort(seq)) + '\n')

        with open(os.path.join(directory, 'vocab.txt'), 'w') as f:
            for c in char_list:
                f.write(c + '\n')
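
# Minimal usage sketch (not part of the original module): generate the synthetic
# copy task and print the first source/target pair. The paths below assume the
# layout written above, i.e. 'data/<dataset_name>/train.in' and 'train.out'.
if __name__ == '__main__':
    prepare_dataset('copy')
    with open(os.path.join('data', 'copy', 'train.in')) as f_in, \
         open(os.path.join('data', 'copy', 'train.out')) as f_out:
        print('source:', f_in.readline().strip())
        print('target:', f_out.readline().strip())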