#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Split a large file into a train and valid set while respecting document
boundaries. Documents should be separated by a single empty line.
"""

import argparse
import random
import sys


def _write_docs(path, docs, separate_docs):
    """Write ``docs`` (lists of newline-terminated lines) to ``path``.

    When ``separate_docs`` is true a single empty line is emitted between
    consecutive documents; otherwise the lines are written back to back.
    """
    with open(path, 'w', encoding='utf-8') as out:
        for idx, doc in enumerate(docs):
            if separate_docs and idx > 0:
                out.write("\n")
            out.writelines(doc)


def main(argv=None):
    """Split ``input`` into a random sample of ``k`` documents and the rest.

    Documents are runs of non-empty lines separated by empty lines; with
    ``--lines`` every non-empty line is treated as its own document. The
    sample is drawn with reservoir sampling, so the input is read exactly
    once. The ``k`` sampled documents go to ``sample_output`` and all other
    documents go to ``remainder_output``.

    Args:
        argv: optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` when ``None``.

    Raises:
        SystemExit: if ``-k`` is missing or negative, or if the input holds
            fewer than ``k`` documents (nothing is written in that case).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input')
    parser.add_argument('sample_output', help='train output file')
    parser.add_argument('remainder_output', help='valid output file')
    parser.add_argument('-k', type=int,
                        help='number of docs (or lines) written to sample_output')
    parser.add_argument('--lines', action='store_true',
                        help='split lines instead of docs')
    args = parser.parse_args(argv)

    # Explicit validation instead of ``assert``, which is stripped under -O.
    if args.k is None or args.k < 0:
        parser.error('-k is required and must be a non-negative integer')

    sample = []
    remainder = []
    num_docs = 0

    def update_sample(doc):
        """Feed one finished document into the reservoir and reset ``doc``."""
        nonlocal num_docs
        if not doc:
            # Repeated/leading blank lines yield empty documents; they must
            # not occupy one of the k sample slots.
            return
        if len(sample) < args.k:
            sample.append(doc.copy())
        else:
            # Reservoir sampling: the new document replaces a random sampled
            # one with probability k / (num_docs + 1).
            j = random.randrange(num_docs + 1)
            if j < args.k:
                remainder.append(sample[j])
                sample[j] = doc.copy()
            else:
                remainder.append(doc.copy())
        num_docs += 1
        doc.clear()

    with open(args.input, 'r', encoding='utf-8') as h:
        doc = []
        for i, line in enumerate(h):
            if line.strip() == "":  # empty line indicates new document
                update_sample(doc)
            else:
                if not line.endswith("\n"):
                    # Only the last line of a file can lack a newline; add it
                    # so the line cannot merge with whatever is written next.
                    line += "\n"
                doc.append(line)
                if args.lines:
                    update_sample(doc)
            # Progress indicator: a count every 1M lines, a dot every 100k.
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        update_sample(doc)  # flush the trailing document, if any
    print(file=sys.stderr, flush=True)

    if len(sample) != args.k:
        sys.exit(
            'error: input contains only {} documents but -k {} was requested'
            .format(len(sample), args.k)
        )

    separate_docs = not args.lines
    _write_docs(args.sample_output, sample, separate_docs)
    _write_docs(args.remainder_output, remainder, separate_docs)


if __name__ == '__main__':
    main()