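"""Preprocess a loose-JSON corpus into an indexed binary dataset.

Each input line is a JSON object; the text under --json-key is split into
sentences with NLTK punkt, tokenized with a BERT WordPiece vocabulary, and
written out as <output-prefix>.bin / <output-prefix>.idx.

Example invocation (paths are illustrative):
    python preprocess_data.py --input corpus.json --vocab vocab.txt \
        --output-prefix corpus --workers 20 --dataset-impl mmap
"""
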
import argparse
import itertools
import json
import multiprocessing
import sys
import time

import nltk
import torch
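# Allow importing the local tokenizer/ and data/ packages via relative paths.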
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
from tokenizer.bert_tokenization import FullTokenizer
from data.indexed_dataset import make_builder

class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
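    """PunktLanguageVars that keep trailing whitespace with each sentence.

    The modified _period_context_fmt stops the splitter from swallowing the
    whitespace (including newlines) after a sentence boundary, which is what
    --keep-newlines relies on."""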

    _period_context_fmt = r"""
        \S*                          # some word material
        %(SentEndChars)s             # a potential sentence ending
        \s*                       #  <-- THIS is what I changed
        (?=(?P<after_tok>
            %(NonWord)s              # either other punctuation
            |
            (?P<next_tok>\S+)     #  <-- Normally you would have \s+ here
        ))"""

class Encoder(object):
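    """Sentence-splits and tokenizes one JSON line per call.

    splitter and tokenizer are class attributes so each multiprocessing
    worker builds them once in initializer() and reuses them for every
    encode() call."""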
    splitter = None
    tokenizer = None
    def __init__(self, args):
        self.args = args

    def initializer(self):
        # Use Encoder class as a container for global data
        Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
        splitter = nltk.load("tokenizers/punkt/english.pickle")
        if self.args.keep_newlines:
            # this prevents punkt from eating newlines after sentences
            Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
                train_text=splitter._params,
                lang_vars=CustomLanguageVars())
        else:
            Encoder.splitter = splitter

    def encode(self, json_line):
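        """Return (doc_ids, num_bytes) for one JSON input line.

        doc_ids is a list of token-id lists, one per sentence (or a single
        flattened list when --flatten is given); num_bytes is the length of
        the raw line, used for throughput reporting."""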
        text = json.loads(json_line)[self.args.json_key]
        if not text:
            text = "no text"
        doc_ids = []
        for sentence in Encoder.splitter.tokenize(text):
            tokens = Encoder.tokenizer.tokenize(sentence)
            ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
            if len(ids) > 0:
                doc_ids.append(ids)
            else:
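                # Never emit an empty sentence; fall back to a placeholder so
                # the document still contributes at least one id sequence.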
                print("no ids!", flush=True)
                tokens = Encoder.tokenizer.tokenize("no text")
                ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
                doc_ids.append(ids)
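        # With --flatten, merge all sentence id lists into one sequence per document.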
        if self.args.flatten and len(doc_ids) > 1:
            doc_ids = [list(itertools.chain(*doc_ids))]
        return doc_ids, len(json_line)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, help='Path to input JSON')
    parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
    parser.add_argument('--flatten', action='store_true',
                        help='Flatten all sentences of a document into a single sequence of ids')
    parser.add_argument('--json-key', type=str, default='text',
                        help='Key to extract from json')
    parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix')
    parser.add_argument('--workers', type=int, default=20,
                        help='Number of worker processes to launch')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Interval between progress updates')
    parser.add_argument('--keep-newlines', action='store_true',
                        help='Keep newlines between sentences.')
    parser.add_argument('--dataset-impl', type=str, default='mmap',
                        choices=['lazy', 'cached', 'mmap'])
    args = parser.parse_args()
    args.keep_empty = False

    startup_start = time.time()

    print("Opening", args.input)
    fin = open(args.input, 'r', encoding='utf-8')

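    # Fetch the punkt sentence model up front so every worker can load it.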
    nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
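    # Each worker builds its own tokenizer/splitter in encoder.initializer;
    # input lines are streamed through the pool in chunks of 25.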
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
    encoded_docs = pool.imap(encoder.encode, fin, 25)

    print(f"Vocab size: {tokenizer.vocab_size()}")

    output_bin_file = "{}.bin".format(args.output_prefix)
    output_idx_file = "{}.idx".format(args.output_prefix)
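    # Token tensors are streamed into the .bin file; the index is written to
    # the .idx file by finalize() at the end.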
    builder = make_builder(output_bin_file,
                           impl=args.dataset_impl,
                           vocab_size=tokenizer.vocab_size())

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)
    for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
        total_bytes_processed += bytes_processed
        for sentence in doc:
            #print(sentence)
            #print(tokenizer.convert_ids_to_tokens(sentence))
            builder.add_item(torch.IntTensor(sentence))
        builder.end_document()
        if i % args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {i} documents",
                  f"({i/elapsed} docs/s, {mbs} MB/s).",
                  file=sys.stderr)

    builder.finalize(output_idx_file)

if __name__ == '__main__':
    main()