# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Takes a corpora of files (specified by `--input_files`) with json data separated
by newlines (loose json). Splits data into train.json, val.json, test.json files
under `output_dir`.

Note: this script can overwrite existing files named
train.json, dev.json, and test.json in `--output_dir`.
"""
import os
import argparse
import math
import random

parser = argparse.ArgumentParser('resplit loose json data into train/val/test')
parser.add_argument('--input_files', nargs='+', required=True,
                    help='whitespace separated list of input data files')
parser.add_argument('--output_dir', required=True,
                    help='output directory where to put files')
parser.add_argument('--test_percent', type=float, nargs='+', default=[0.05, 0],
                    help='one or two fractions of the data to hold out: the first '
                         'for the dev (validation) split and the optional second '
                         'for the test split')
args = parser.parse_args()

def get_lines(filepath):
    # Read a loose-json file and return its lines with surrounding whitespace stripped.
    lines = []
    with open(filepath, 'r') as f:
        for line in f:
            lines.append(line.strip())
    return lines

def get_splits(lines, line_counts):
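    # Pool the lines from every input file (tracking each line's source file
    # index and original line number), shuffle them with a single random
    # permutation, and carve the pool into consecutive chunks of the
    # requested sizes.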
    all_lines = []
    line_idx = []
    file_mappings = []
    for i, l in enumerate(lines):
        all_lines.extend(l)
        line_idx.extend(list(range(len(l))))
        file_mappings.extend([i] * len(l))

    indices = list(range(len(all_lines)))
    random.shuffle(indices)
    all_lines = [all_lines[idx] for idx in indices]
    line_idx = [line_idx[idx] for idx in indices]
    file_mappings = [file_mappings[idx] for idx in indices]

    splits = []
    mappings = []
    start = 0
    for end in line_counts:
        end += start
        splits.append(all_lines[start:end])
        mappings.append(format_mappings(line_idx[start:end], file_mappings[start:end]))
        start = end
    return splits, mappings

def format_mappings(line_idx, file_mappings):
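    # Render each (source file index, original line number) pair as a
    # tab-separated row for the .map files.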
    lines = []
    for m, l in zip(file_mappings, line_idx):
        lines.append(str(m).strip() + '\t' + str(l).strip())
    return lines


def get_filepaths(filepaths, output_dir):
    paths = []
    train_path = 'train.json'
    dev_path = 'dev.json'
    test_path = 'test.json'
    paths.append(os.path.join(output_dir, train_path))
    paths.append(os.path.join(output_dir, dev_path))
    paths.append(os.path.join(output_dir, test_path))
    return paths

def write_files(lines, mappings, filepaths):
    for l, m, path in zip(lines, mappings, filepaths):
        write_file(l, path)
        write_mapping_file(m, path)

def write_file(lines, path):
    print('Writing:', path)
    with open(path, 'w') as f:
        for l in lines:
            f.write(l + '\n')


def write_mapping_file(m, path):
    path = path + '.map'
    m = [get_mapping_header()] + m
    write_file(m, path)

def get_mapping_header():
    return 'file\tline #'
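
# Each split is accompanied by a `<split>.json.map` file recording, for every
# output line, the index of its source file in `--input_files` and its original
# line number there (tab-separated). For example (values illustrative):
#
#   file    line #
#   0       12
#   2       305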

if not os.path.exists(args.output_dir):
    os.makedirs(args.output_dir)

lines = []

for filepath in args.input_files:
    _lines = get_lines(filepath)
    lines.append(_lines)

# calculate the number of lines to assign to each split
line_counts = [len(l) for l in lines]
total_lines = sum(line_counts)
dev_percent = args.test_percent[0]
dev_lines = math.ceil(dev_percent * total_lines)
test_percent = 0
if len(args.test_percent) == 2:
    test_percent = args.test_percent[1]
test_lines = math.ceil(test_percent * total_lines)
train_lines = total_lines - (test_lines + dev_lines)
normed_lines = [train_lines, dev_lines, test_lines]
normed_lines = [int(l) for l in normed_lines]
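
# For example, with 1000 total input lines and `--test_percent 0.05 0.01`,
# dev_lines = ceil(0.05 * 1000) = 50, test_lines = ceil(0.01 * 1000) = 10,
# and train_lines = 1000 - (10 + 50) = 940.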


splits, mappings = get_splits(lines, normed_lines)
filepaths = get_filepaths(args.input_files, args.output_dir)
print('Writing output to:', filepaths)
write_files(splits, mappings, filepaths)