"vscode:/vscode.git/clone" did not exist on "9fcdff3ed1deda3fb9a5b961d3f9aa9b7813219d"
# preprocess_data.py
import argparse
import json
import multiprocessing
import nltk
import sys
import time

import torch

from bert_tokenization import FullTokenizer
import indexed_dataset

class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):

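    # Note: the only change from punkt's default _period_context_fmt is the
    # whitespace handling flagged in the comments below, so that whitespace
    # (including newlines) after a sentence ending stays attached to the
    # sentence. This is what --keep-newlines relies on.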
    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""

class Encoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data; each pool worker
        # calls this once, so the tokenizer and splitter are built per process.
        Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
        splitter = nltk.load("tokenizers/punkt/english.pickle")
        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = splitter

    def encode(self, json_line):
        text = json.loads(json_line)[self.args.json_key]
        doc_ids = []
        for sentence in Encoder.splitter.tokenize(text):
            tokens = Encoder.tokenizer.tokenize(sentence)
            ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
            doc_ids.append(ids)
        return doc_ids, len(json_line)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str,
                        help='Path to input JSON-lines file (one document per line)')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--json-key', type=str, default='text',
                        help='Key to extract from json')
    parser.add_argument('--output-prefix', type=str,
                        help='Path to binary output file without suffix')
    parser.add_argument('--workers', type=int, default=20,
                        help='Number of worker processes to launch')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
    parser.add_argument('--keep-newlines', action='store_true',
                        help='Keep newlines between sentences.')
    parser.add_argument('--dataset-impl', type=str, default='mmap',
                        choices=['lazy', 'cached', 'mmap'])
    args = parser.parse_args()
    args.keep_empty = False

    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

    nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
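    # Each pool worker runs encoder.initializer() once at startup, building its
    # own tokenizer and sentence splitter; imap then streams input lines to the
    # workers in chunks of 25 and yields (doc_ids, bytes_processed) pairs in
    # input order.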
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)

    print(f"Vocab size: {tokenizer.vocab_size()}")

    output_bin_file = "{}.bin".format(args.output_prefix)
    output_idx_file = "{}.idx".format(args.output_prefix)
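    # The builder below appends token ids to <output-prefix>.bin;
    # builder.finalize() later writes the matching <output-prefix>.idx index.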
    builder = indexed_dataset.make_builder(output_bin_file,
                                           impl=args.dataset_impl,
                                           vocab_size=tokenizer.vocab_size())

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)
    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for sentence in doc:
            #print(sentence)
            #print(tokenizer.convert_ids_to_tokens(sentence))
            builder.add_item(torch.IntTensor(sentence))
        builder.end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {i} documents",
104
105
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    builder.finalize(output_idx_file)

if __name__ == '__main__':
    main()
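
# Example invocation (illustrative only; the file paths are placeholders and
# must point at your own JSON-lines corpus and BERT vocab file):
#
#   python preprocess_data.py \
#       --input corpus.json \
#       --vocab vocab.txt \
#       --json-key text \
#       --output-prefix corpus_bert \
#       --workers 20 \
#       --dataset-impl mmap
#
# Each line of --input must be a JSON object containing the --json-key field;
# the tokenized output is written to corpus_bert.bin and corpus_bert.idx.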