# split_json.py
"""
Takes a corpus of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, dev.json, test.json files
under `--output_dir`.

Note: This code has the potential to overwrite files with the names
train.json, dev.json, test.json in `--output_dir`.
"""
import os
import argparse
import math
import random

# Command-line interface. Arguments are parsed at import time, so the rest of
# the script can read them from the module-level `args` namespace.
parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
# One or more input files; each is read line by line.
parser.add_argument('--input_files', nargs='+', required=True,
                    help='whitespace separated list of input data files')
# Directory that receives the split files (created later if it is missing).
parser.add_argument('--output_dir', required=True,
                    help='output directory where to put files')
# One or more floats. The script reads the first value as the dev share and,
# only when exactly two values are given, the second as the test share. Each
# value is multiplied directly by the total line count, i.e. it is used as a
# fraction (0.05 == 5%), not as a number out of 100.
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='percentage of available data to use for val/test dataset')
args = parser.parse_args()

Neel Kant's avatar
Neel Kant committed
23

24
25
26
27
28
29
30
31
def get_lines(filepath):
    """Read `filepath` and return its lines with surrounding whitespace removed.

    Every line of the file produces exactly one entry (a blank line becomes an
    empty string), so an entry's position in the returned list equals its
    zero-based line number in the file.
    """
    with open(filepath, 'r') as handle:
        return [raw.strip() for raw in handle]

Neel Kant's avatar
Neel Kant committed
32

33
34
35
36
37
38
39
def get_splits(lines, line_counts):
    """Shuffle the lines of all input files together and cut them into splits.

    Args:
        lines: list containing one list of lines per input file.
        line_counts: number of lines to put in each split, in output order.

    Returns:
        A `(splits, mappings)` tuple of equal-length lists. `splits[k]` holds
        the lines assigned to the k-th split. `mappings[k]` holds, for each of
        those lines in the same order, the `format_mappings` entry recording
        the index of the input file it came from and its line index within
        that file.
    """
    # Flatten the per-file lists while remembering, for every line, its index
    # inside its file and the index of that file in `lines`.
    all_lines = []
    line_idx = []
    file_mappings = []
    for file_idx, file_lines in enumerate(lines):
        all_lines.extend(file_lines)
        line_idx.extend(range(len(file_lines)))
        file_mappings.extend([file_idx] * len(file_lines))

    # Shuffle a single index permutation and apply it to all three lists so
    # that each line stays aligned with its provenance information.
    indices = list(range(len(all_lines)))
    random.shuffle(indices)
    all_lines = [all_lines[idx] for idx in indices]
    line_idx = [line_idx[idx] for idx in indices]
    file_mappings = [file_mappings[idx] for idx in indices]

    # Hand out consecutive slices of the shuffled data, one per requested
    # count; `start` tracks where the previous split ended.
    splits = []
    mappings = []
    start = 0
    for count in line_counts:
        end = start + count
        splits.append(all_lines[start:end])
        mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
        start = end
    return splits, mappings

Neel Kant's avatar
Neel Kant committed
58

59
60
61
def format_mappings(line_idx, file_mappings):
    """Format provenance pairs as tab-separated text rows.

    Args:
        line_idx: per-line index of the line within its source file.
        file_mappings: per-line index of the source file.

    Returns:
        A list of strings of the form '<file index>\\t<line index>', one per
        pair. The two inputs are walked in lockstep, so the result is as long
        as the shorter of the two.
    """
    return [
        str(file_idx).strip() + '\t' + str(idx).strip()
        for file_idx, idx in zip(file_mappings, line_idx)
    ]


def get_filepaths(filepaths, output_dir):
    """Return the train/dev/test output paths located inside `output_dir`.

    `filepaths` is accepted but not used; the output file names are fixed.
    """
    split_names = ('train.json', 'dev.json', 'test.json')
    return [os.path.join(output_dir, name) for name in split_names]

Neel Kant's avatar
Neel Kant committed
76

77
78
79
80
81
def write_files(lines, mappings, filepaths):
    """Write every split and its mapping file.

    The three arguments are walked in lockstep: the k-th entry of `lines` is
    written to the k-th path, and the k-th entry of `mappings` is written
    alongside it by `write_mapping_file`.
    """
    for split_lines, split_mapping, destination in zip(lines, mappings, filepaths):
        write_file(split_lines, destination)
        write_mapping_file(split_mapping, destination)

Neel Kant's avatar
Neel Kant committed
82

83
84
85
86
def write_file(lines, path):
    """Write `lines` to `path`, one entry per line.

    The destination is announced on stdout, then opened in 'w' mode, so any
    existing file at `path` is replaced. Each entry is followed by a newline.
    """
    print('Writing:', path)
    with open(path, 'w') as f:
        f.writelines(line + '\n' for line in lines)

89
90

def write_mapping_file(m, path):
    """Write the mapping rows `m` to `path + '.map'`, preceded by a header row.

    Args:
        m: list of mapping rows (strings) for one split.
        path: path of the split's data file; '.map' is appended to it.
    """
    path = path + '.map'
    # Build a new list rather than mutating the caller's mapping rows.
    m = [get_mapping_header()] + m
    write_file(m, path)

Neel Kant's avatar
Neel Kant committed
95

96
97
98
def get_mapping_header():
    """Return the tab-separated header row placed at the top of a mapping file."""
    header = 'file\tline #'
    return header

Neel Kant's avatar
Neel Kant committed
99

100
101
102
103
104
105
106
107
108
# Make sure the output directory exists. `exist_ok=True` replaces a separate
# os.path.exists() check, which could race with another process creating the
# directory between the check and the makedirs call.
os.makedirs(args.output_dir, exist_ok=True)

# Load every input file; lines[i] holds the stripped lines of the i-th file
# given on the command line.
lines = [get_lines(filepath) for filepath in args.input_files]

Neel Kant's avatar
Neel Kant committed
109
# calculate number of lines to use for each
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)

# The first --test_percent value is the dev fraction. The counts are rounded
# up with ceil, so any non-zero fraction of a non-empty corpus gets at least
# one line.
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)

# The test fraction is only taken from the command line when exactly two
# values were supplied; otherwise the test split is empty.
test_percent = 0
if len(args.test_percent) == 2:
    test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)

# Whatever is left after dev and test goes to train.
train_lines = total_lines - (test_lines + dev_lines)

# Order must match the paths returned by get_filepaths: train, dev, test.
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]


splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)